diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 39eac793249..00000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,247 +0,0 @@ -version: 2.1 - -commands: - checkout_merge: - description: "checkout merge branch" - steps: - - checkout - designate_upload_channel: - description: "inserts the correct upload channel into ${BASH_ENV}" - steps: - - run: - name: adding UPLOAD_CHANNEL to BASH_ENV - command: | - our_upload_channel=nightly - # On tags upload to test instead - if [[ -n "${CIRCLE_TAG}" ]] || [[ ${CIRCLE_BRANCH} =~ release/* ]]; then - our_upload_channel=test - fi - echo "export UPLOAD_CHANNEL=${our_upload_channel}" >> ${BASH_ENV} - apt_install: - parameters: - args: - type: string - descr: - type: string - default: "" - update: - type: boolean - default: true - steps: - - run: - name: > - <<^ parameters.descr >> apt install << parameters.args >> <> - <<# parameters.descr >> << parameters.descr >> <> - command: | - <<# parameters.update >> sudo apt update -qy <> - sudo apt install << parameters.args >> - pip_install: - parameters: - args: - type: string - descr: - type: string - default: "" - user: - type: boolean - default: true - steps: - - run: - name: > - <<^ parameters.descr >> pip install << parameters.args >> <> - <<# parameters.descr >> << parameters.descr >> <> - command: > - pip install - <<# parameters.user >> --user <> - --progress-bar=off - << parameters.args >> - - install_torchrl: - parameters: - editable: - type: boolean - default: true - steps: - - pip_install: - args: --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - descr: Install PyTorch from nightly releases - - pip_install: - args: --no-build-isolation <<# parameters.editable >> --editable <> . - descr: Install torchrl <<# parameters.editable >> in editable mode <> - - -binary_common: &binary_common - parameters: - # Edit these defaults to do a release - build_version: - description: "version number of release binary; by default, build a nightly" - type: string - default: "" - pytorch_version: - description: "PyTorch version to build against; by default, use a nightly" - type: string - default: "" - # Don't edit these - python_version: - description: "Python version to build against (e.g., 3.7)" - type: string - cu_version: - description: "CUDA version to build against, in CU format (e.g., cpu or cu100)" - type: string - default: "cpu" - unicode_abi: - description: "Python 2.7 wheel only: whether or not we are cp27mu (default: no)" - type: string - default: "" - conda_docker_image: - description: "Conda only: what docker image to use" - type: string - default: "pytorch/conda-builder:cpu" - environment: - PYTHON_VERSION: << parameters.python_version >> - PYTORCH_VERSION: << parameters.pytorch_version >> - UNICODE_ABI: << parameters.unicode_abi >> - CU_VERSION: << parameters.cu_version >> - -smoke_test_common: &smoke_test_common - <<: *binary_common - docker: - - image: torchrl/smoke_test:latest - -jobs: - type_check_python: - docker: - - image: circleci/python:3.7 - steps: - - checkout - - pip_install: - args: cmake ninja - descr: Install CMake and Ninja - - install_torchrl: - editable: true - - pip_install: - args: mypy - descr: Install Python type check utilities - - run: - name: Check Python types statically - command: mypy --install-types --non-interactive --config-file mypy.ini - - unittest_linux_envpool_gpu: - <<: *binary_common - machine: - image: ubuntu-2004-cuda-11.4:202110-01 - resource_class: gpu.nvidia.medium - environment: - 
image_name: "pytorch/manylinux-cuda117" - TAR_OPTIONS: --no-same-owner - PYTHON_VERSION: << parameters.python_version >> - CU_VERSION: << parameters.cu_version >> - - steps: - - checkout - - designate_upload_channel - - run: - name: Generate cache key - # This will refresh cache on Sundays, nightly build should generate new cache. - command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - keys: - - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_envpool/environment.yml" }}-{{ checksum ".circleci-weekly" }} - - run: - name: Setup - command: .circleci/unittest/linux_libs/scripts_envpool/setup_env.sh - - save_cache: - key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_envpool/environment.yml" }}-{{ checksum ".circleci-weekly" }} - paths: - - conda - - env - - run: - name: Install torchrl - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_libs/scripts_envpool/install.sh - - run: - name: Run tests - command: bash .circleci/unittest/linux_libs/scripts_envpool/run_test.sh - - run: - name: Codecov upload - command: | - curl -Os https://uploader.codecov.io/latest/linux/codecov - chmod +x codecov - ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-envpool - - run: - name: Post Process - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_envpool/post_process.sh - - store_test_results: - path: test-results - - unittest_linux_stable_gpu: - <<: *binary_common - machine: - image: ubuntu-2004-cuda-11.4:202110-01 - resource_class: gpu.nvidia.medium - environment: - image_name: "pytorch/manylinux-cuda116" - TAR_OPTIONS: --no-same-owner - PYTHON_VERSION: << parameters.python_version >> - CU_VERSION: << parameters.cu_version >> - - steps: - - checkout - - designate_upload_channel - - run: - name: Generate cache key - # This will refresh cache on Sundays, nightly build should generate new cache. - command: echo "$(date +"%Y-%U")" > .circleci-weekly - - restore_cache: - - keys: - - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - - - run: - name: Setup - command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/setup_env.sh - - save_cache: - - key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} - - paths: - - conda - - env -# - run: -# # Here we create an envlist file that contains some env variables that we want the docker container to be aware of. -# # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables. -# # They're available in all the other workflows (OSX and Windows). -# # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible. -# # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run". 
-# name: export CIRCLECI env var -# command: echo "CIRCLECI=true" >> ./env.list - - run: - name: Install torchrl -# command: bash .circleci/unittest/linux_stable/scripts/install.sh - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_stable/scripts/install.sh - - run: - name: Run tests - command: bash .circleci/unittest/linux_stable/scripts/run_test.sh -# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh - - run: - name: Codecov upload - command: | - curl -Os https://uploader.codecov.io/latest/linux/codecov - chmod +x codecov - ./codecov -t ${CODECOV_TOKEN} -s ./ -Z -F linux-stable-gpu - - run: - name: Post Process - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/post_process.sh - - store_test_results: - path: test-results - -workflows: - unittest: - jobs: - - unittest_linux_envpool_gpu: - cu_version: cu117 - name: unittest_linux_envpool_gpu_py3.8 - python_version: '3.8' - - unittest_linux_stable_gpu: - cu_version: cu116 - name: unittest_linux_stable_gpu_py3.9 - python_version: '3.9' diff --git a/.circleci/docs/setup_env.sh b/.circleci/docs/setup_env.sh deleted file mode 100755 index 496e57b29bd..00000000000 --- a/.circleci/docs/setup_env.sh +++ /dev/null @@ -1,66 +0,0 @@ -#apt-get update -y -#apt-get install software-properties-common -y -#add-apt-repository ppa:git-core/candidate -y -#apt-get update -y -#apt-get upgrade -y -#apt-get -y install libglfw3 libglew2.0 gcc curl g++ unzip \ -# wget sudo git cmake libz-dev \ -# zlib1g-dev python3.8 python3-pip ninja - -#yum install -y mesa-libGL freeglut egl-utils glew glfw -#yum install -y glew glfw -apt-get update && apt-get install -y git wget gcc g++ - -root_dir="$(pwd)" -conda_dir="${root_dir}/conda" -env_dir="${root_dir}/env" - -os=Linux - -# 1. Install conda at ./conda -printf "* Installing conda\n" -wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" -bash ./miniconda.sh -b -f -p "${conda_dir}" - -eval "$(${conda_dir}/bin/conda shell.bash hook)" - -printf "* Creating a test environment\n" -conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" - -printf "* Activating\n" -conda activate "${env_dir}" - -conda install -c conda-forge zlib -y - -pip3 install --upgrade pip --quiet --root-user-action=ignore - -printf "python version\n" -python --version - -pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu118 --quiet --root-user-action=ignore -#pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu --quiet --root-user-action=ignore - -printf "Installing tensordict\n" -pip3 install git+https://github.com/pytorch-labs/tensordict.git --quiet --root-user-action=ignore - -printf "Installing torchrl\n" -pip3 install -e . 
--quiet --root-user-action=ignore - -printf "Installing requirements\n" -pip3 install -r docs/requirements.txt --quiet --root-user-action=ignore -printf "Installed all dependencies\n" - -printf "smoke test\n" -PYOPENGL_PLATFORM=egl MUJOCO_GL=egl python3 -c """from torchrl.envs.libs.dm_control import DMControlEnv -print(DMControlEnv('cheetah', 'run').reset()) -""" - -printf "building docs...\n" -cd ./docs -#timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build SPHINXOPTS=-v ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi -PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build -cd .. -printf "done!\n" - -git clone --branch gh-pages https://github.com/pytorch-labs/tensordict.git docs/_local_build/tensordict -rm -rf docs/_local_build/tensordict/.git diff --git a/.circleci/unittest/linux_distributed/scripts/setup_env.sh b/.circleci/unittest/linux_distributed/scripts/setup_env.sh deleted file mode 100755 index 5d17d3eaec6..00000000000 --- a/.circleci/unittest/linux_distributed/scripts/setup_env.sh +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env bash - -# This script is for setting up environment in which unit test is ran. -# To speed up the CI time, the resulting environment is cached. -# -# Do not install PyTorch and torchvision here, otherwise they also get cached. - -set -e - -this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# Avoid error: "fatal: unsafe repository" -git config --global --add safe.directory '*' -root_dir="$(git rev-parse --show-toplevel)" -conda_dir="${root_dir}/conda" -env_dir="${root_dir}/env" -lib_dir="${env_dir}/lib" - -cd "${root_dir}" - -case "$(uname -s)" in - Darwin*) os=MacOSX;; - *) os=Linux -esac - -# 1. Install conda at ./conda -if [ ! -d "${conda_dir}" ]; then - printf "* Installing conda\n" - wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" - bash ./miniconda.sh -b -f -p "${conda_dir}" -fi -eval "$(${conda_dir}/bin/conda shell.bash hook)" - -# 2. Create test environment at ./env -printf "python: ${PYTHON_VERSION}\n" -if [ ! -d "${env_dir}" ]; then - printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" -fi -conda activate "${env_dir}" - -## 3. Install mujoco -#printf "* Installing mujoco and related\n" -#mkdir -p $root_dir/.mujoco -#cd $root_dir/.mujoco/ -#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz -#wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -#tar -xf mujoco210-linux-x86_64.tar.gz -#cd $this_dir - -# 4. Install Conda dependencies -printf "* Installing dependencies (except PyTorch)\n" -echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" -cat "${this_dir}/environment.yml" - - -if [[ $OSTYPE == 'darwin'* ]]; then - PRIVATE_MUJOCO_GL=glfw -elif [ "${CU_VERSION:-}" == cpu ]; then - PRIVATE_MUJOCO_GL=osmesa -else - PRIVATE_MUJOCO_GL=egl -fi - -export MUJOCO_GL=$PRIVATE_MUJOCO_GL -conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ - DISPLAY=unix:0.0 \ - MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ - SDL_VIDEODRIVER=dummy \ - MUJOCO_GL=$PRIVATE_MUJOCO_GL \ - PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL - -# Software rendering requires GLX and OSMesa. 
-if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then - yum makecache - yum install -y glfw - yum install -y glew - yum install -y mesa-libGL - yum install -y mesa-libGL-devel - yum install -y mesa-libOSMesa-devel - yum -y install egl-utils - yum -y install freeglut -fi - -pip install pip --upgrade - -conda env update --file "${this_dir}/environment.yml" --prune - -conda deactivate -conda activate "${env_dir}" - -if [[ $OSTYPE != 'darwin'* ]]; then - # install ale-py: manylinux names are broken for CentOS so we need to manually download and - # rename them - PY_VERSION=$(python --version) - echo "installing ale-py for ${PY_PY_VERSION}" - if [[ $PY_VERSION == *"3.7"* ]]; then - wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.8"* ]]; then - wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.9"* ]]; then - wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.10"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - fi - echo "installing gymnasium" - pip install "gymnasium[atari,accept-rom-license]" -else - pip install "gymnasium[atari,accept-rom-license]" -fi diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh deleted file mode 100755 index a9f361f49c0..00000000000 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ /dev/null @@ -1,253 +0,0 @@ -#!/usr/bin/env bash - -# Leave blank as code needs to start on line 29 for run_local.sh -# -# -# -# -# -# -# - -set -e -set -v - -export PYTORCH_TEST_WITH_SLOW='1' -python -m torch.utils.collect_env -# Avoid error: "fatal: unsafe repository" -git config --global --add safe.directory '*' - -root_dir="$(git rev-parse --show-toplevel)" -env_dir="${root_dir}/env" -lib_dir="${env_dir}/lib" - -# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir -export MKL_THREADING_LAYER=GNU - -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v 
--durations 200 - -# With batched environments -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ - env.num_envs=1 \ - env.device=cuda:0 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - optim.device=cuda:0 \ - loss.mini_batch_size=10 \ - loss.ppo_epochs=1 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - optimization.batch_size=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=4 \ - collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ - network.device=cuda:0 \ - optimization.utd_ratio=1 \ - replay_buffer.size=120 \ - env.name=Pendulum-v1 \ - logger.backend= -# record_video=True \ -# record_frames=4 \ -python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \ - env.num_envs=1 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False \ - optim.device=cuda:0 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=4 \ - collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ - replay_buffer.size=120 \ - env.name=Pendulum-v1 \ - logger.backend= -# logger.record_video=True \ -# logger.record_frames=4 \ -python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - optimization.batch_size=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=4 \ - collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ - network.device=cuda:0 \ - logger.mode=offline \ - env.name=Pendulum-v1 \ - logger.backend= -python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ - total_frames=48 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - device=cuda:0 \ - mode=offline - -# With single envs -python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ 
- num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - optimization.batch_size=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=2 \ - collector.env_per_collector=1 \ - collector.collector_device=cuda:0 \ - network.device=cuda:0 \ - optimization.utd_ratio=1 \ - replay_buffer.size=120 \ - env.name=Pendulum-v1 \ - logger.backend= -# record_video=True \ -# record_frames=4 \ -python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \ - env.num_envs=1 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False \ - optim.device=cuda:0 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ - env.num_envs=1 \ - env.device=cuda:0 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - optim.device=cuda:0 \ - loss.mini_batch_size=10 \ - loss.ppo_epochs=1 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False -python .circleci/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=2 \ - collector.env_per_collector=1 \ - collector.collector_device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ - replay_buffer.size=120 \ - env.name=Pendulum-v1 \ - logger.backend= -# record_video=True \ -# record_frames=4 \ -python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ - total_frames=48 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=2 \ - env_per_collector=1 \ - mode=offline \ - device=cuda:0 \ - collector_device=cuda:0 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - optimization.batch_size=10 \ - collector.frames_per_batch=16 \ - collector.num_workers=2 \ - collector.env_per_collector=1 \ - logger.mode=offline \ - collector.collector_device=cuda:0 \ - env.name=Pendulum-v1 \ - logger.backend= - -python .circleci/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100 - -coverage combine -coverage xml -i diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh deleted file mode 100755 index 088a7572d0b..00000000000 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env bash - -set -e - -eval 
"$(./conda/bin/conda shell.bash hook)" -conda activate ./env - -export PYTORCH_TEST_WITH_SLOW='1' -python -m torch.utils.collect_env -# Avoid error: "fatal: unsafe repository" -git config --global --add safe.directory '*' - -root_dir="$(git rev-parse --show-toplevel)" -env_dir="${root_dir}/env" -lib_dir="${env_dir}/lib" - -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir -export MKL_THREADING_LAYER=GNU - -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control' - -export DISPLAY=':99.0' -Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & -CKPT_BACKEND=torch MUJOCO_GL=egl python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py -#pytest --instafail -v --durations 200 -#python test/test_libs.py -coverage combine -coverage xml -i diff --git a/.circleci/unittest/linux_stable/scripts/install.sh b/.circleci/unittest/linux_stable/scripts/install.sh deleted file mode 100755 index 53b208092eb..00000000000 --- a/.circleci/unittest/linux_stable/scripts/install.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env bash - -unset PYTORCH_VERSION -# For unittest, nightly PyTorch is used as the following section, -# so no need to set PYTORCH_VERSION. -# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. - -set -e - -eval "$(./conda/bin/conda shell.bash hook)" -conda activate ./env - -if [ "${CU_VERSION:-}" == cpu ] ; then - version="cpu" - echo "Using cpu build" -else - if [[ ${#CU_VERSION} -eq 4 ]]; then - CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" - elif [[ ${#CU_VERSION} -eq 5 ]]; then - CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" - fi - echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" - version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" -fi - -# submodules -git submodule sync && git submodule update --init --recursive - -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu -else - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -fi - -# install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git - -# smoke test -python -c "import torch;import functorch" - -printf "* Installing torchrl\n" -printf "g++ version: " -gcc --version - -python setup.py develop diff --git a/.circleci/unittest/linux_stable/scripts/run_test.sh b/.circleci/unittest/linux_stable/scripts/run_test.sh deleted file mode 100755 index 1f936c73691..00000000000 --- a/.circleci/unittest/linux_stable/scripts/run_test.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -set -e - -eval "$(./conda/bin/conda shell.bash hook)" -conda activate ./env - -export PYTORCH_TEST_WITH_SLOW='1' -python -m torch.utils.collect_env -# Avoid error: "fatal: unsafe repository" -git config --global --add safe.directory '*' - -root_dir="$(git rev-parse --show-toplevel)" -env_dir="${root_dir}/env" -lib_dir="${env_dir}/lib" - -# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir -export MKL_THREADING_LAYER=GNU -export CKPT_BACKEND=torch - -python 
.circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py -coverage combine -coverage xml -i diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh new file mode 100644 index 00000000000..2df580b5801 --- /dev/null +++ b/.github/scripts/m1_script.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +export BUILD_VERSION=0.2.0 diff --git a/.github/scripts/pre_build_script_m1.sh b/.github/scripts/pre_build_script_m1.sh new file mode 100644 index 00000000000..8f98c05c9a8 --- /dev/null +++ b/.github/scripts/pre_build_script_m1.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python3 -mpip install git+https://github.com/pytorch/tensordict.git diff --git a/.circleci/unittest/helpers/coverage_run_parallel.py b/.github/unittest/helpers/coverage_run_parallel.py similarity index 100% rename from .circleci/unittest/helpers/coverage_run_parallel.py rename to .github/unittest/helpers/coverage_run_parallel.py diff --git a/.circleci/unittest/linux/scripts/10_nvidia.json b/.github/unittest/linux/scripts/10_nvidia.json similarity index 100% rename from .circleci/unittest/linux/scripts/10_nvidia.json rename to .github/unittest/linux/scripts/10_nvidia.json diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml similarity index 93% rename from .circleci/unittest/linux/scripts/environment.yml rename to .github/unittest/linux/scripts/environment.yml index f27bae7da6c..7125cfff04b 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -28,3 +28,5 @@ dependencies: - av - coverage - ray + - transformers + - ninja diff --git a/.circleci/unittest/linux/scripts/post_process.sh b/.github/unittest/linux/scripts/post_process.sh similarity index 100% rename from .circleci/unittest/linux/scripts/post_process.sh rename to .github/unittest/linux/scripts/post_process.sh diff --git a/.circleci/unittest/linux/scripts/run-clang-format.py b/.github/unittest/linux/scripts/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux/scripts/run-clang-format.py rename to .github/unittest/linux/scripts/run-clang-format.py diff --git a/.circleci/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh similarity index 81% rename from .circleci/unittest/linux/scripts/run_all.sh rename to .github/unittest/linux/scripts/run_all.sh index 85a46061b13..f682abe47f5 100755 --- a/.circleci/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -90,6 +90,7 @@ conda activate "${env_dir}" echo "installing gymnasium" pip3 install "gymnasium[atari,ale-py,accept-rom-license]" pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py +pip3 install mujoco -U # sanity check: remove? 
python3 -c """ @@ -121,10 +122,21 @@ fi git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu + else + pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$CU_VERSION + fi else - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + printf "Failed to install pytorch" + exit 1 fi # smoke test @@ -134,7 +146,7 @@ python -c "import functorch" pip3 install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip3 install git+https://github.com/pytorch-labs/tensordict.git +pip3 install git+https://github.com/pytorch/tensordict.git printf "* Installing torchrl\n" python setup.py develop @@ -166,16 +178,16 @@ python -m torch.utils.collect_env #export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch - +export MAX_IDLE_COUNT=100 pytest test/smoke_test.py -v --durations 200 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' if [ "${CU_VERSION:-}" != cpu ] ; then - python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test \ - --instafail --durations 200 --ignore test/test_rlhf.py + python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ + --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py else - python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test \ - --instafail --durations 200 --ignore test/test_rlhf.py --ignore test/test_distributed.py + python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ + --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py fi coverage combine diff --git a/.circleci/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml similarity index 100% rename from .circleci/unittest/linux_distributed/scripts/environment.yml rename to .github/unittest/linux_distributed/scripts/environment.yml diff --git a/.circleci/unittest/linux_distributed/scripts/install.sh b/.github/unittest/linux_distributed/scripts/install.sh similarity index 79% rename from .circleci/unittest/linux_distributed/scripts/install.sh rename to .github/unittest/linux_distributed/scripts/install.sh index d0f3f7a132e..95eda22aecb 100755 --- a/.circleci/unittest/linux_distributed/scripts/install.sh +++ b/.github/unittest/linux_distributed/scripts/install.sh @@ -28,9 +28,9 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu + pip3 install --pre torch torchvision torchaudio 
--index-url https://download.pytorch.org/whl/nightly/cpu else - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi # smoke test @@ -40,7 +40,7 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git printf "* Installing torchrl\n" python setup.py develop diff --git a/.circleci/unittest/linux_distributed/scripts/post_process.sh b/.github/unittest/linux_distributed/scripts/post_process.sh similarity index 100% rename from .circleci/unittest/linux_distributed/scripts/post_process.sh rename to .github/unittest/linux_distributed/scripts/post_process.sh diff --git a/.circleci/unittest/linux_distributed/scripts/run-clang-format.py b/.github/unittest/linux_distributed/scripts/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_distributed/scripts/run-clang-format.py rename to .github/unittest/linux_distributed/scripts/run-clang-format.py diff --git a/.circleci/unittest/linux_distributed/scripts/run_test.sh b/.github/unittest/linux_distributed/scripts/run_test.sh similarity index 57% rename from .circleci/unittest/linux_distributed/scripts/run_test.sh rename to .github/unittest/linux_distributed/scripts/run_test.sh index 88f09536622..211159de4e1 100755 --- a/.circleci/unittest/linux_distributed/scripts/run_test.sh +++ b/.github/unittest/linux_distributed/scripts/run_test.sh @@ -19,8 +19,8 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_stable/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh similarity index 89% rename from .circleci/unittest/linux_stable/scripts/setup_env.sh rename to .github/unittest/linux_distributed/scripts/setup_env.sh index 1d02f8ecd0c..501dbe1c914 100755 --- a/.circleci/unittest/linux_stable/scripts/setup_env.sh +++ b/.github/unittest/linux_distributed/scripts/setup_env.sh @@ -112,11 +112,14 @@ if [[ $OSTYPE != 'darwin'* ]]; then mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + elif [[ $PY_VERSION == *"3.11"* ]]; then + wget 
https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi echo "installing gymnasium" pip install "gymnasium[atari,accept-rom-license]" - pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py else pip install "gymnasium[atari,accept-rom-license]" - pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py fi diff --git a/.circleci/unittest/linux_examples/scripts/10_nvidia.json b/.github/unittest/linux_examples/scripts/10_nvidia.json similarity index 100% rename from .circleci/unittest/linux_examples/scripts/10_nvidia.json rename to .github/unittest/linux_examples/scripts/10_nvidia.json diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml similarity index 83% rename from .circleci/unittest/linux_examples/scripts/environment.yml rename to .github/unittest/linux_examples/scripts/environment.yml index 7a91696ca46..688921f826a 100644 --- a/.circleci/unittest/linux_examples/scripts/environment.yml +++ b/.github/unittest/linux_examples/scripts/environment.yml @@ -20,10 +20,12 @@ dependencies: - pyyaml - scipy - hydra-core - - tensorboard - imageio==2.26.0 - - wandb - dm_control - mlflow - av - coverage + - vmas + - transformers + - gym[atari] + - gym[accept-rom-license] diff --git a/.circleci/unittest/linux_examples/scripts/post_process.sh b/.github/unittest/linux_examples/scripts/post_process.sh similarity index 100% rename from .circleci/unittest/linux_examples/scripts/post_process.sh rename to .github/unittest/linux_examples/scripts/post_process.sh diff --git a/.circleci/unittest/linux_examples/scripts/run-clang-format.py b/.github/unittest/linux_examples/scripts/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_examples/scripts/run-clang-format.py rename to .github/unittest/linux_examples/scripts/run-clang-format.py diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh similarity index 72% rename from .circleci/unittest/linux_examples/scripts/run_all.sh rename to .github/unittest/linux_examples/scripts/run_all.sh index 4ff05b1bbee..6bf73ff1b95 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.github/unittest/linux_examples/scripts/run_all.sh @@ -7,29 +7,16 @@ set -v # ================================ Init ============================================== # -if [[ $OSTYPE != 'darwin'* ]]; then - apt-get update && apt-get upgrade -y - apt-get install -y vim git wget - - apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 - - if [ "${CU_VERSION:-}" == cpu ] ; then - # solves version `GLIBCXX_3.4.29' not found for tensorboard -# apt-get install -y gcc-4.9 - apt-get upgrade -y libstdc++6 - apt-get dist-upgrade -y - else - apt-get install -y g++ gcc - fi +apt-get update && apt-get upgrade -y +apt-get install -y vim git wget -fi +apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev +apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 +apt-get 
install -y g++ gcc patchelf this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -if [[ $OSTYPE != 'darwin'* ]]; then - # from cudagl docker image - cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json -fi +# from cudagl docker image +cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json # ==================================================================================== # @@ -69,8 +56,8 @@ conda activate "${env_dir}" printf "* Installing mujoco and related\n" mkdir -p $root_dir/.mujoco cd $root_dir/.mujoco/ -wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz tar -xf mujoco210-linux-x86_64.tar.gz cd "${root_dir}" @@ -80,9 +67,16 @@ printf "* Installing dependencies (except PyTorch)\n" echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" cat "${this_dir}/environment.yml" +export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 +export DISPLAY=unix:0.0 +#export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin +export SDL_VIDEODRIVER=dummy +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl + conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ DISPLAY=unix:0.0 \ - MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=egl \ @@ -95,6 +89,19 @@ conda env update --file "${this_dir}/environment.yml" --prune conda deactivate conda activate "${env_dir}" +# install d4rl +pip install free-mujoco-py +pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl + +# TODO: move this down -- will break torchrl installation +conda install -y -c conda-forge libstdcxx-ng=12 +## find libstdc +STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) +conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC + +# compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) +python -c """import gym;import d4rl""" + # install ale-py: manylinux names are broken for CentOS so we need to manually download and # rename them PY_VERSION=$(python --version) @@ -115,6 +122,11 @@ elif [[ $PY_VERSION == *"3.10"* ]]; then mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.11"* ]]; then + wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi pip install "gymnasium[atari,accept-rom-license]" @@ -134,7 +146,7 @@ version="$(python -c 
"print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION +pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION # smoke test python -c "import functorch" @@ -143,7 +155,7 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git printf "* Installing torchrl\n" python setup.py develop diff --git a/.circleci/unittest/linux_examples/scripts/run_local.sh b/.github/unittest/linux_examples/scripts/run_local.sh similarity index 59% rename from .circleci/unittest/linux_examples/scripts/run_local.sh rename to .github/unittest/linux_examples/scripts/run_local.sh index 9fd521add7a..75bc84e4ca8 100755 --- a/.circleci/unittest/linux_examples/scripts/run_local.sh +++ b/.github/unittest/linux_examples/scripts/run_local.sh @@ -1,9 +1,10 @@ #!/bin/bash set -e +set -v # Read script from line 29 -filename=".circleci/unittest/linux_examples/scripts/run_test.sh" +filename=".github/unittest/linux_examples/scripts/run_test.sh" start_line=29 script=$(tail -n +$start_line "$filename") script="set -e"$'\n'"$script" @@ -11,8 +12,8 @@ script="set -e"$'\n'"$script" # Replace "cuda:0" with "cpu" script="${script//cuda:0/cpu}" -# Remove any instances of ".circleci/unittest/helpers/coverage_run_parallel.py" -script="${script//.circleci\/unittest\/helpers\/coverage_run_parallel.py}" +# Remove any instances of ".github/unittest/helpers/coverage_run_parallel.py" +script="${script//.github\/unittest\/helpers\/coverage_run_parallel.py}" script="${script//coverage combine}" script="${script//coverage xml -i}" diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh new file mode 100755 index 00000000000..4d58117f58c --- /dev/null +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -0,0 +1,291 @@ +#!/usr/bin/env bash + +# Leave blank as code needs to start on line 29 for run_local.sh +# +# +# +# +# +# +# + +set -e +set -v + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 +#python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 + +# ==================================================================================== # +# ================================ gym 0.23 ========================================== # + +# With batched environments +python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 \ + optim.device=cuda:0 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \ + 
optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 \ + optim.device=cuda:0 \ + logger.backend= + +# ==================================================================================== # +# ================================ Gymnasium ========================================= # + +python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ + env.env_name=HalfCheetah-v4 \ + collector.total_frames=40 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=10 \ + loss.ppo_epochs=2 \ + logger.backend= \ + logger.test_interval=10 +python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=20 \ + loss.ppo_epochs=2 \ + logger.backend= \ + logger.test_interval=10 +python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.collector_device=cuda:0 \ + network.device=cuda:0 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + logger.backend= +# record_video=True \ +# record_frames=4 \ +python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_mujoco.py \ + env.env_name=HalfCheetah-v4 \ + collector.total_frames=40 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=10 \ + logger.backend= \ + logger.test_interval=40 +python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_atari.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=20 \ + logger.backend= \ + logger.test_interval=40 +python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=4 \ + env_per_collector=2 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=4 \ + env_per_collector=2 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.collector_device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + network.device=cuda:0 \ + logger.backend= +# logger.record_video=True \ +# logger.record_frames=4 \ +python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ + total_frames=200 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=200 \ + num_workers=4 \ + env_per_collector=2 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 \ + rssm_hidden_dim=17 +python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.num_workers=4 \ + collector.env_per_collector=2 \ + collector.collector_device=cuda:0 \ + 
network.device=cuda:0 \ + logger.mode=offline \ + env.name=Pendulum-v1 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ + total_frames=48 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=4 \ + env_per_collector=2 \ + collector_device=cuda:0 \ + device=cuda:0 \ + mode=offline \ + logger= + +# With single envs +python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ + total_frames=200 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=200 \ + num_workers=2 \ + env_per_collector=1 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 \ + rssm_hidden_dim=17 +python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ + collector.collector_device=cuda:0 \ + network.device=cuda:0 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + logger.backend= +# record_video=True \ +# record_frames=4 \ +python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_device=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ + collector.collector_device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + network.device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ + total_frames=48 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + mode=offline \ + device=cuda:0 \ + collector_device=cuda:0 \ + logger= +python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.num_workers=2 \ + collector.env_per_collector=1 \ + logger.mode=offline \ + collector.collector_device=cuda:0 \ + env.name=Pendulum-v1 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/maddpg_iddpg.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/iql.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + 
logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/qmix_vdn.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + logger.backend= + +python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100 + +## RLHF +# RLHF tests are executed in the dedicated workflow + +coverage combine +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_brax/environment.yml b/.github/unittest/linux_libs/scripts_brax/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_brax/environment.yml rename to .github/unittest/linux_libs/scripts_brax/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_brax/install.sh b/.github/unittest/linux_libs/scripts_brax/install.sh similarity index 79% rename from .circleci/unittest/linux_libs/scripts_brax/install.sh rename to .github/unittest/linux_libs/scripts_brax/install.sh index 1b3f34cb0bd..b3a42967935 100755 --- a/.circleci/unittest/linux_libs/scripts_brax/install.sh +++ b/.github/unittest/linux_libs/scripts_brax/install.sh @@ -30,13 +30,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --progress-bar off + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --progress-bar off else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall --progress-bar off + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall --progress-bar off fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git --progress-bar off +pip install git+https://github.com/pytorch/tensordict.git --progress-bar off # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_brax/post_process.sh b/.github/unittest/linux_libs/scripts_brax/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_brax/post_process.sh rename to .github/unittest/linux_libs/scripts_brax/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_brax/run-clang-format.py b/.github/unittest/linux_libs/scripts_brax/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_brax/run-clang-format.py rename to .github/unittest/linux_libs/scripts_brax/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_brax/run_all.sh b/.github/unittest/linux_libs/scripts_brax/run_all.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_brax/run_all.sh rename to .github/unittest/linux_libs/scripts_brax/run_all.sh diff --git a/.circleci/unittest/linux_libs/scripts_brax/run_test.sh b/.github/unittest/linux_libs/scripts_brax/run_test.sh similarity index 86% rename from .circleci/unittest/linux_libs/scripts_brax/run_test.sh rename to .github/unittest/linux_libs/scripts_brax/run_test.sh index 9fa45ae4737..6a4dac48331 100755 --- 
a/.circleci/unittest/linux_libs/scripts_brax/run_test.sh +++ b/.github/unittest/linux_libs/scripts_brax/run_test.sh @@ -30,6 +30,6 @@ python -c "import brax.envs" python -c "import jax" python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))' -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_brax/setup_env.sh b/.github/unittest/linux_libs/scripts_brax/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_brax/setup_env.sh rename to .github/unittest/linux_libs/scripts_brax/setup_env.sh diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/environment.yml b/.github/unittest/linux_libs/scripts_d4rl/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_d4rl/environment.yml rename to .github/unittest/linux_libs/scripts_d4rl/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/install.sh b/.github/unittest/linux_libs/scripts_d4rl/install.sh similarity index 83% rename from .circleci/unittest/linux_libs/scripts_d4rl/install.sh rename to .github/unittest/linux_libs/scripts_d4rl/install.sh index 437900b3323..feb922d14b8 100755 --- a/.circleci/unittest/linux_libs/scripts_d4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/install.sh @@ -33,13 +33,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/post_process.sh b/.github/unittest/linux_libs/scripts_d4rl/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_d4rl/post_process.sh rename to .github/unittest/linux_libs/scripts_d4rl/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/run-clang-format.py b/.github/unittest/linux_libs/scripts_d4rl/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_d4rl/run-clang-format.py rename to .github/unittest/linux_libs/scripts_d4rl/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/run_test.sh b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh similarity index 59% rename from .circleci/unittest/linux_libs/scripts_d4rl/run_test.sh rename to .github/unittest/linux_libs/scripts_d4rl/run_test.sh index f453290b7fd..3723399a859 100755 --- a/.circleci/unittest/linux_libs/scripts_d4rl/run_test.sh +++ 
b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh @@ -11,7 +11,7 @@ ln -s /usr/bin/swig3.0 /usr/bin/swig # we install d4rl here bc env variables have been updated git clone https://github.com/Farama-Foundation/d4rl.git cd d4rl -pip3 install -U 'mujoco-py<2.1,>=2.0' +#pip3 install -U 'mujoco-py<2.1,>=2.0' pip3 install -U "gym[classic_control,atari,accept-rom-license]"==0.23 pip3 install -U six pip install -e . @@ -37,6 +37,25 @@ conda deactivate && conda activate ./env # this workflow only tests the libs python -c "import gym, d4rl" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestD4RL --error-for-skips coverage combine coverage xml -i + +## check what happens if we update gym +#pip install gym -U +#python -c """ +#from torchrl.data.datasets import D4RLExperienceReplay +#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=True) +#for batch in data: +# print(batch) +# break +# +#data = D4RLExperienceReplay('halfcheetah-medium-v2', batch_size=10, from_env=False, direct_download=False) +#for batch in data: +# print(batch) +# break +# +#import d4rl +#import gym +#gym.make('halfcheetah-medium-v2') +#""" diff --git a/.circleci/unittest/linux_libs/scripts_d4rl/setup_env.sh b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh similarity index 86% rename from .circleci/unittest/linux_libs/scripts_d4rl/setup_env.sh rename to .github/unittest/linux_libs/scripts_d4rl/setup_env.sh index c28f1a4b350..c1985c239ee 100755 --- a/.circleci/unittest/linux_libs/scripts_d4rl/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh @@ -6,6 +6,7 @@ # Do not install PyTorch and torchvision here, otherwise they also get cached. set -e +set -v this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" @@ -39,6 +40,12 @@ if [ ! -d "${env_dir}" ]; then fi conda activate "${env_dir}" +#pip3 uninstall cython -y +#pip uninstall cython -y +#conda uninstall cython -y +pip3 install "cython<3" +conda install -c anaconda cython="<3.0.0" -y + # 3. Install mujoco printf "* Installing mujoco and related\n" @@ -47,12 +54,15 @@ cd $root_dir/.mujoco/ #wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz #tar -xf mujoco-2.1.1-linux-x86_64.tar.gz #wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -wget https://www.roboti.us/download/mujoco200_linux.zip +wget https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/mujoco200_linux.zip unzip mujoco200_linux.zip -wget https://www.roboti.us/file/mjkey.txt +wget https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/mjkey.txt cp mjkey.txt ./mujoco200_linux/bin/ # install mujoco-py locally git clone https://github.com/vmoens/mujoco-py.git +cd mujoco-py +git checkout v2.0.2.1 +pip install -e . cd $this_dir # 4. Install Conda dependencies @@ -60,7 +70,7 @@ printf "* Installing dependencies (except PyTorch)\n" echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" cat "${this_dir}/environment.yml" -pip install pip --upgrade +pip3 install pip --upgrade # 5. 
env variables if [[ $OSTYPE == 'darwin'* ]]; then diff --git a/.circleci/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_envpool/environment.yml rename to .github/unittest/linux_libs/scripts_envpool/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_envpool/install.sh b/.github/unittest/linux_libs/scripts_envpool/install.sh similarity index 81% rename from .circleci/unittest/linux_libs/scripts_envpool/install.sh rename to .github/unittest/linux_libs/scripts_envpool/install.sh index 5899209cc46..c62a2de25fb 100755 --- a/.circleci/unittest/linux_libs/scripts_envpool/install.sh +++ b/.github/unittest/linux_libs/scripts_envpool/install.sh @@ -28,16 +28,16 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu118 + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 fi # smoke test python -c "import functorch" # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch/tensordict printf "* Installing torchrl\n" python setup.py develop diff --git a/.circleci/unittest/linux_libs/scripts_envpool/post_process.sh b/.github/unittest/linux_libs/scripts_envpool/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_envpool/post_process.sh rename to .github/unittest/linux_libs/scripts_envpool/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_envpool/run-clang-format.py b/.github/unittest/linux_libs/scripts_envpool/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_envpool/run-clang-format.py rename to .github/unittest/linux_libs/scripts_envpool/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_envpool/run_test.sh b/.github/unittest/linux_libs/scripts_envpool/run_test.sh similarity index 84% rename from .circleci/unittest/linux_libs/scripts_envpool/run_test.sh rename to .github/unittest/linux_libs/scripts_envpool/run_test.sh index b00c79527da..289adf454e7 100755 --- a/.circleci/unittest/linux_libs/scripts_envpool/run_test.sh +++ b/.github/unittest/linux_libs/scripts_envpool/run_test.sh @@ -27,6 +27,6 @@ export MKL_THREADING_LAYER=GNU # this workflow only tests the libs python -c "import envpool" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestEnvPool --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestEnvPool --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_envpool/setup_env.sh b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh similarity index 87% rename from .circleci/unittest/linux_libs/scripts_envpool/setup_env.sh rename to .github/unittest/linux_libs/scripts_envpool/setup_env.sh index 3f405c584ce..bb5c09079ea 100755 --- a/.circleci/unittest/linux_libs/scripts_envpool/setup_env.sh +++ 
b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh @@ -74,6 +74,11 @@ if [[ $OSTYPE != 'darwin'* ]]; then mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + elif [[ $PY_VERSION == *"3.11"* ]]; then + wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi echo "installing gym" # envpool does not currently work with gymnasium diff --git a/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh similarity index 88% rename from .circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh rename to .github/unittest/linux_libs/scripts_gym/batch_scripts.sh index abb76df994b..321da982d2e 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -50,6 +50,8 @@ do conda activate ./cloned_env echo "Testing gym version: ${GYM_VERSION}" + # handling https://github.com/openai/gym/issues/3202 + pip3 install wheel==0.38.4 pip3 install gym==$GYM_VERSION $DIR/run_test.sh @@ -67,6 +69,7 @@ do conda activate ./cloned_env echo "Testing gym version: ${GYM_VERSION}" + pip3 install wheel==0.38.4 pip3 install 'gym[atari]'==$GYM_VERSION pip3 install ale-py==0.7 $DIR/run_test.sh @@ -144,12 +147,18 @@ do mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + elif [[ $PY_VERSION == *"3.11"* ]]; then + wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi pip install gymnasium[atari] else pip install gymnasium[atari] fi pip install mo-gymnasium + pip install gymnasium-robotics $DIR/run_test.sh diff --git a/.circleci/unittest/linux_libs/scripts_gym/environment.yml b/.github/unittest/linux_libs/scripts_gym/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_gym/environment.yml rename to .github/unittest/linux_libs/scripts_gym/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh similarity index 96% rename from .circleci/unittest/linux_libs/scripts_gym/install.sh rename to .github/unittest/linux_libs/scripts_gym/install.sh index 959269c1b16..718e4f37e3a 100755 --- 
a/.circleci/unittest/linux_libs/scripts_gym/install.sh +++ b/.github/unittest/linux_libs/scripts_gym/install.sh @@ -46,7 +46,7 @@ fi pip install -U --force-reinstall charset-normalizer # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_gym/post_process.sh b/.github/unittest/linux_libs/scripts_gym/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_gym/post_process.sh rename to .github/unittest/linux_libs/scripts_gym/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_gym/run-clang-format.py b/.github/unittest/linux_libs/scripts_gym/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_gym/run-clang-format.py rename to .github/unittest/linux_libs/scripts_gym/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh b/.github/unittest/linux_libs/scripts_gym/run_test.sh similarity index 58% rename from .circleci/unittest/linux_libs/scripts_gym/run_test.sh rename to .github/unittest/linux_libs/scripts_gym/run_test.sh index b710b77c9b9..2e5860468c3 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh +++ b/.github/unittest/linux_libs/scripts_gym/run_test.sh @@ -17,11 +17,11 @@ lib_dir="${env_dir}/lib" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym' +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym' export DISPLAY=':99.0' Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym" --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_gym/setup_env.sh b/.github/unittest/linux_libs/scripts_gym/setup_env.sh similarity index 91% rename from .circleci/unittest/linux_libs/scripts_gym/setup_env.sh rename to .github/unittest/linux_libs/scripts_gym/setup_env.sh index c17bd82d120..8804370aa6d 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_gym/setup_env.sh @@ -57,7 +57,7 @@ mkdir -p mujoco_py/binaries/linux \ && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ && tar -xf mujoco.tar.gz -C mujoco_py/binaries/linux \ && rm mujoco.tar.gz -wget https://www.roboti.us/file/mjkey.txt +wget https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/mjkey.txt cp mjkey.txt mujoco_py/binaries/ pip install -e . cd .. 
@@ -79,9 +79,8 @@ conda env config vars set \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \ MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \ - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin - -# LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/src/nvidia-470.63.01 \ + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin +# LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.circleci/unittest/linux_libs/scripts_habitat/10_nvidia.json b/.github/unittest/linux_libs/scripts_habitat/10_nvidia.json similarity index 100% rename from .circleci/unittest/linux_libs/scripts_habitat/10_nvidia.json rename to .github/unittest/linux_libs/scripts_habitat/10_nvidia.json diff --git a/.circleci/unittest/linux_libs/scripts_habitat/environment.yml b/.github/unittest/linux_libs/scripts_habitat/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_habitat/environment.yml rename to .github/unittest/linux_libs/scripts_habitat/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_habitat/install.sh b/.github/unittest/linux_libs/scripts_habitat/install.sh similarity index 81% rename from .circleci/unittest/linux_libs/scripts_habitat/install.sh rename to .github/unittest/linux_libs/scripts_habitat/install.sh index 82170d7fd8b..316cf9e3225 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/install.sh +++ b/.github/unittest/linux_libs/scripts_habitat/install.sh @@ -20,10 +20,10 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall # install tensordict -pip3 install git+https://github.com/pytorch-labs/tensordict.git +pip3 install git+https://github.com/pytorch/tensordict.git # smoke test python3 -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_habitat/post_process.sh b/.github/unittest/linux_libs/scripts_habitat/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_habitat/post_process.sh rename to .github/unittest/linux_libs/scripts_habitat/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_habitat/run-clang-format.py b/.github/unittest/linux_libs/scripts_habitat/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_habitat/run-clang-format.py rename to .github/unittest/linux_libs/scripts_habitat/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_habitat/run_all.sh b/.github/unittest/linux_libs/scripts_habitat/run_all.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_habitat/run_all.sh rename to .github/unittest/linux_libs/scripts_habitat/run_all.sh diff --git a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh b/.github/unittest/linux_libs/scripts_habitat/run_test.sh similarity index 90% rename from .circleci/unittest/linux_libs/scripts_habitat/run_test.sh rename to 
.github/unittest/linux_libs/scripts_habitat/run_test.sh index 1f916fea9c1..5c9becfe832 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_test.sh @@ -47,6 +47,6 @@ env = HabitatEnv('HabitatRenderPick-v0') env.reset() """ -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestHabitat --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestHabitat --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh similarity index 80% rename from .circleci/unittest/linux_libs/scripts_habitat/setup_env.sh rename to .github/unittest/linux_libs/scripts_habitat/setup_env.sh index 9fc68255fcd..d287f8a5977 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -39,9 +39,24 @@ if [ ! -d "${env_dir}" ]; then conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" fi conda activate "${env_dir}" +#pip3 uninstall cython -y +#pip uninstall cython -y +#conda uninstall cython -y +pip3 install "cython<3" +conda install -c anaconda cython="<3.0.0" -y -# 3. Install Conda dependencies +# 3. Install git LFS +mkdir git_lfs +wget https://github.com/git-lfs/git-lfs/releases/download/v2.9.0/git-lfs-linux-amd64-v2.9.0.tar.gz --directory-prefix git_lfs +cd git_lfs +tar -xf git-lfs-linux-amd64-v2.9.0.tar.gz +chmod 755 install.sh +./install.sh +cd .. +git lfs install + +# 4. Install Conda dependencies printf "* Installing dependencies (except PyTorch)\n" echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" cat "${this_dir}/environment.yml" diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/environment.yml b/.github/unittest/linux_libs/scripts_jumanji/environment.yml similarity index 88% rename from .circleci/unittest/linux_libs/scripts_jumanji/environment.yml rename to .github/unittest/linux_libs/scripts_jumanji/environment.yml index 12714559a43..fa16f027ee0 100644 --- a/.circleci/unittest/linux_libs/scripts_jumanji/environment.yml +++ b/.github/unittest/linux_libs/scripts_jumanji/environment.yml @@ -18,5 +18,3 @@ dependencies: - scipy - hydra-core - jumanji - - jax<=0.4.10 - - jaxlib<=0.4.10 diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh b/.github/unittest/linux_libs/scripts_jumanji/install.sh similarity index 81% rename from .circleci/unittest/linux_libs/scripts_jumanji/install.sh rename to .github/unittest/linux_libs/scripts_jumanji/install.sh index 91671e8d985..ee6c747315c 100755 --- a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/install.sh @@ -30,13 +30,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall + pip3 install --pre torch --index-url 
https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/post_process.sh b/.github/unittest/linux_libs/scripts_jumanji/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_jumanji/post_process.sh rename to .github/unittest/linux_libs/scripts_jumanji/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py b/.github/unittest/linux_libs/scripts_jumanji/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py rename to .github/unittest/linux_libs/scripts_jumanji/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/run_test.sh b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh similarity index 83% rename from .circleci/unittest/linux_libs/scripts_jumanji/run_test.sh rename to .github/unittest/linux_libs/scripts_jumanji/run_test.sh index 4b9bb270727..67f86ed73ee 100755 --- a/.circleci/unittest/linux_libs/scripts_jumanji/run_test.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh @@ -28,6 +28,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON # this workflow only tests the libs python -c "import jumanji" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh b/.github/unittest/linux_libs/scripts_jumanji/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh rename to .github/unittest/linux_libs/scripts_jumanji/setup_env.sh diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml new file mode 100644 index 00000000000..76f97355f7a --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml @@ -0,0 +1,23 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - swig + - pip: + - cloudpickle + - gym + - gym-notices + - importlib-metadata + - six + - zipp + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - autorom[accept-rom-license] + - pettingzoo[all]==1.24.1 diff --git a/.circleci/unittest/linux_libs/scripts_vmas/install.sh b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh similarity index 81% rename from .circleci/unittest/linux_libs/scripts_vmas/install.sh rename to .github/unittest/linux_libs/scripts_pettingzoo/install.sh index cb36c7cc48a..0c7bc8f402b 100755 --- a/.circleci/unittest/linux_libs/scripts_vmas/install.sh +++ b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh @@ -30,13 +30,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url 
https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/post_process.sh b/.github/unittest/linux_libs/scripts_pettingzoo/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_rlhf/post_process.sh rename to .github/unittest/linux_libs/scripts_pettingzoo/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/run-clang-format.py b/.github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_rlhf/run-clang-format.py rename to .github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh new file mode 100755 index 00000000000..1cdb653ede8 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import pettingzoo" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestPettingZoo --error-for-skips +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh b/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh new file mode 100755 index 00000000000..a3f833112a9 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! 
-d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +# 5. Download atari roms +autorom_dir="${env_dir}/lib/python${PYTHON_VERSION}/site-packages/AutoROM/roms" +multi_atari_rom_dir="${env_dir}/lib/python${PYTHON_VERSION}/site-packages/multi_agent_ale_py/roms" +ln -s "${autorom_dir}" "${multi_atari_rom_dir}" diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/environment.yml b/.github/unittest/linux_libs/scripts_rlhf/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_rlhf/environment.yml rename to .github/unittest/linux_libs/scripts_rlhf/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh similarity index 83% rename from .circleci/unittest/linux_libs/scripts_rlhf/install.sh rename to .github/unittest/linux_libs/scripts_rlhf/install.sh index 76c10f36e6c..25a73fd6dff 100755 --- a/.circleci/unittest/linux_libs/scripts_rlhf/install.sh +++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh @@ -33,13 +33,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/post_process.sh b/.github/unittest/linux_libs/scripts_rlhf/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_sklearn/post_process.sh rename to .github/unittest/linux_libs/scripts_rlhf/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py b/.github/unittest/linux_libs/scripts_rlhf/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py rename to .github/unittest/linux_libs/scripts_rlhf/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/run_test.sh b/.github/unittest/linux_libs/scripts_rlhf/run_test.sh similarity index 50% rename from .circleci/unittest/linux_libs/scripts_rlhf/run_test.sh rename to .github/unittest/linux_libs/scripts_rlhf/run_test.sh index 641838e6612..bdbe1b18ff1 
100755 --- a/.circleci/unittest/linux_libs/scripts_rlhf/run_test.sh +++ b/.github/unittest/linux_libs/scripts_rlhf/run_test.sh @@ -21,6 +21,15 @@ conda deactivate && conda activate ./env python -c "import transformers, datasets" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips + +python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \ + sys.device=cuda:0 sys.ref_device=cuda:0 \ + model.name_or_path=gpt2 train.max_epochs=2 \ + data.batch_size=2 train.ppo.ppo_batch_size=2 \ + train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \ + train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \ + data.block_size=110 io.logger=csv + coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_rlhf/setup_env.sh b/.github/unittest/linux_libs/scripts_rlhf/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_rlhf/setup_env.sh rename to .github/unittest/linux_libs/scripts_rlhf/setup_env.sh diff --git a/.circleci/unittest/linux_stable/scripts/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml similarity index 64% rename from .circleci/unittest/linux_stable/scripts/environment.yml rename to .github/unittest/linux_libs/scripts_robohive/environment.yml index 3ea4037802e..705b8522c92 100644 --- a/.circleci/unittest/linux_stable/scripts/environment.yml +++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml @@ -3,9 +3,10 @@ channels: - defaults dependencies: - pip - - ninja - protobuf - pip: + # Initial version is required to install Atari ROMS in setup_env.sh + - gym==0.13 - hypothesis - future - cloudpickle @@ -17,15 +18,11 @@ dependencies: - pytest-mock - pytest-instafail - pytest-rerunfailures + - pytest-error-for-skips - expecttest - pyyaml - scipy - hydra-core - - tensorboard - - imageio==2.26.0 - - wandb - - dm_control - - mlflow - - av - - coverage - - ray + - patchelf + - mujoco==2.3.3 + - dm_control==1.0.11 diff --git a/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh b/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh new file mode 100755 index 00000000000..68fe922ec5d --- /dev/null +++ b/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-dev g++ gcc libosmesa6-dev + +set -e +set -v + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +#apt-get update -y && apt-get install git wget gcc g++ -y + +if [ "${CU_VERSION:-}" == cpu ] ; then + cudatoolkit="cpuonly" + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" + cudatoolkit="cudatoolkit=${version}" +fi + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" + +# Extracted from run_test.sh to run once. + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 20 + +# let's make sure we have a GPU at our disposal +python -c """ +import torch +devcount = torch.cuda.device_count() +assert devcount +print('device count', devcount) +""" + +echo $MUJOCO_GL +echo $sim_backend + +sim_backend=MUJOCO MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 20 -k "robohive" --error-for-skips +coverage combine +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_vmas/post_process.sh b/.github/unittest/linux_libs/scripts_robohive/post_process.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_vmas/post_process.sh rename to .github/unittest/linux_libs/scripts_robohive/post_process.sh diff --git a/.circleci/unittest/linux_libs/scripts_vmas/run-clang-format.py b/.github/unittest/linux_libs/scripts_robohive/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_libs/scripts_vmas/run-clang-format.py rename to .github/unittest/linux_libs/scripts_robohive/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh new file mode 100755 index 00000000000..50625f1e906 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. 
+# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +#git clone https://github.com/vmoens/mujoco-py.git +#cd mujoco-py +#git checkout aws_fix2 +#mkdir -p mujoco_py/binaries/linux \ +# && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ +# && tar -xf mujoco.tar.gz -C mujoco_py/binaries/linux \ +# && rm mujoco.tar.gz +#wget https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/mjkey.txt +#cp mjkey.txt mujoco_py/binaries/ +#pip install -e . +#cd .. + +#cd $this_dir + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +export MUJOCO_GL=egl +conda env config vars set \ + MUJOCO_GL=egl \ + SDL_VIDEODRIVER=dummy \ + DISPLAY=unix:0.0 \ + PYOPENGL_PLATFORM=egl \ + NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ + sim_backend=MUJOCO + +# make env variables apparent +conda deactivate && conda activate "${env_dir}" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +pip install git+https://github.com/vikashplus/robohive@main diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/environment.yml b/.github/unittest/linux_libs/scripts_sklearn/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_sklearn/environment.yml rename to .github/unittest/linux_libs/scripts_sklearn/environment.yml diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/install.sh b/.github/unittest/linux_libs/scripts_sklearn/install.sh similarity index 83% rename from .circleci/unittest/linux_libs/scripts_sklearn/install.sh rename to .github/unittest/linux_libs/scripts_sklearn/install.sh index 437900b3323..feb922d14b8 100755 --- a/.circleci/unittest/linux_libs/scripts_sklearn/install.sh +++ b/.github/unittest/linux_libs/scripts_sklearn/install.sh @@ -33,13 +33,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release # conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 
--force-reinstall + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh b/.github/unittest/linux_libs/scripts_sklearn/post_process.sh similarity index 100% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh rename to .github/unittest/linux_libs/scripts_sklearn/post_process.sh diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/run-clang-format.py b/.github/unittest/linux_libs/scripts_sklearn/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/run-clang-format.py rename to .github/unittest/linux_libs/scripts_sklearn/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/run_test.sh b/.github/unittest/linux_libs/scripts_sklearn/run_test.sh similarity index 77% rename from .circleci/unittest/linux_libs/scripts_sklearn/run_test.sh rename to .github/unittest/linux_libs/scripts_sklearn/run_test.sh index 56830393799..ef9b119cbc0 100755 --- a/.circleci/unittest/linux_libs/scripts_sklearn/run_test.sh +++ b/.github/unittest/linux_libs/scripts_sklearn/run_test.sh @@ -22,6 +22,6 @@ conda deactivate && conda activate ./env # this workflow only tests the libs python -c "import sklearn, pandas" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenML --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenML --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh b/.github/unittest/linux_libs/scripts_sklearn/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh rename to .github/unittest/linux_libs/scripts_sklearn/setup_env.sh diff --git a/.github/unittest/linux_libs/scripts_smacv2/environment.yml b/.github/unittest/linux_libs/scripts_smacv2/environment.yml new file mode 100644 index 00000000000..d1e1e1f5edc --- /dev/null +++ b/.github/unittest/linux_libs/scripts_smacv2/environment.yml @@ -0,0 +1,21 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - cloudpickle + - gym + - gym-notices + - importlib-metadata + - zipp + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - numpy==1.23.0 + - git+https://github.com/oxwhirl/smacv2.git diff --git a/.github/unittest/linux_libs/scripts_smacv2/install.sh b/.github/unittest/linux_libs/scripts_smacv2/install.sh new file mode 100755 index 00000000000..0c7bc8f402b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_smacv2/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" diff --git a/.circleci/unittest/linux_optdeps/scripts/post_process.sh b/.github/unittest/linux_libs/scripts_smacv2/post_process.sh similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/post_process.sh rename to .github/unittest/linux_libs/scripts_smacv2/post_process.sh diff --git a/.circleci/unittest/linux_optdeps/scripts/run-clang-format.py b/.github/unittest/linux_libs/scripts_smacv2/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/run-clang-format.py rename to .github/unittest/linux_libs/scripts_smacv2/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_smacv2/run_test.sh b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh new file mode 100755 index 00000000000..65fd7462df3 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" +export SC2PATH="${root_dir}/StarCraftII" +echo 'SC2PATH is set to ' "$SC2PATH" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import smacv2" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestSmacv2 --error-for-skips +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.github/unittest/linux_libs/scripts_smacv2/setup_env.sh new file mode 100755 index 00000000000..04080cc8932 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. 
+# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +# 5. Install StarCraft 2 with SMACv2 maps +starcraft_path="${root_dir}/StarCraftII" +map_dir="${starcraft_path}/Maps" +printf "* Installing StarCraft 2 and SMACv2 maps into ${starcraft_path}\n" +cd "${root_dir}" +# TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. +# e.g adding this into the image learn( which one is used and how it is maintained) +wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip +# The archive contains StarCraftII folder. Password comes from the documentation. +unzip -qo -P iagreetotheeula SC2.4.10.zip +mkdir -p "${map_dir}" +# Install Maps +wget https://github.com/oxwhirl/smacv2/releases/download/maps/SMAC_Maps.zip +unzip SMAC_Maps.zip +mkdir "${map_dir}/SMAC_Maps" +mv *.SC2Map "${map_dir}/SMAC_Maps" +printf "StarCraft II and SMAC are installed." diff --git a/.circleci/unittest/linux_libs/scripts_vmas/environment.yml b/.github/unittest/linux_libs/scripts_vmas/environment.yml similarity index 100% rename from .circleci/unittest/linux_libs/scripts_vmas/environment.yml rename to .github/unittest/linux_libs/scripts_vmas/environment.yml diff --git a/.github/unittest/linux_libs/scripts_vmas/install.sh b/.github/unittest/linux_libs/scripts_vmas/install.sh new file mode 100755 index 00000000000..0c7bc8f402b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_vmas/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" diff --git a/.circleci/unittest/linux_stable/scripts/post_process.sh b/.github/unittest/linux_libs/scripts_vmas/post_process.sh similarity index 100% rename from .circleci/unittest/linux_stable/scripts/post_process.sh rename to .github/unittest/linux_libs/scripts_vmas/post_process.sh diff --git a/.circleci/unittest/linux_stable/scripts/run-clang-format.py b/.github/unittest/linux_libs/scripts_vmas/run-clang-format.py similarity index 100% rename from .circleci/unittest/linux_stable/scripts/run-clang-format.py rename to .github/unittest/linux_libs/scripts_vmas/run-clang-format.py diff --git a/.circleci/unittest/linux_libs/scripts_vmas/run_test.sh b/.github/unittest/linux_libs/scripts_vmas/run_test.sh similarity index 81% rename from .circleci/unittest/linux_libs/scripts_vmas/run_test.sh rename to .github/unittest/linux_libs/scripts_vmas/run_test.sh index dc2bd903c5f..66934039783 100755 --- a/.circleci/unittest/linux_libs/scripts_vmas/run_test.sh +++ b/.github/unittest/linux_libs/scripts_vmas/run_test.sh @@ -25,6 +25,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON # this workflow only tests the libs python -c "import vmas" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestVmas --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestVmas --error-for-skips coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_vmas/setup_env.sh b/.github/unittest/linux_libs/scripts_vmas/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_libs/scripts_vmas/setup_env.sh rename to .github/unittest/linux_libs/scripts_vmas/setup_env.sh diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh similarity index 100% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh rename to .github/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml 
similarity index 100% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/environment.yml rename to .github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh similarity index 96% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh rename to .github/unittest/linux_olddeps/scripts_gym_0_13/install.sh index fc29520cb85..f55daf8e8ce 100755 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -46,7 +46,7 @@ fi pip install -U --force-reinstall charset-normalizer # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/run-clang-format.py b/.github/unittest/linux_olddeps/scripts_gym_0_13/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh new file mode 100755 index 00000000000..d0ca7e3e46d --- /dev/null +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/work/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin +export MKL_THREADING_LAYER=GNU + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control' + +export DISPLAY=':99.0' +Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & +CKPT_BACKEND=torch MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py +#pytest --instafail -v --durations 200 +#python test/test_libs.py +coverage combine +coverage xml -i diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh similarity index 97% rename from .circleci/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh rename to 
.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh index 0b5dc97719d..9e360c4b9c4 100755 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh @@ -57,7 +57,7 @@ mkdir -p mujoco_py/binaries/linux \ && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ && tar -xf mujoco.tar.gz -C mujoco_py/binaries/linux \ && rm mujoco.tar.gz -wget https://www.roboti.us/file/mjkey.txt +wget https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/mjkey.txt cp mjkey.txt mujoco_py/binaries/ pip install -e . cd .. diff --git a/.circleci/unittest/linux_optdeps/scripts/10_nvidia.json b/.github/unittest/linux_optdeps/scripts/10_nvidia.json similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/10_nvidia.json rename to .github/unittest/linux_optdeps/scripts/10_nvidia.json diff --git a/.circleci/unittest/linux_optdeps/scripts/environment.yml b/.github/unittest/linux_optdeps/scripts/environment.yml similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/environment.yml rename to .github/unittest/linux_optdeps/scripts/environment.yml diff --git a/.circleci/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh similarity index 82% rename from .circleci/unittest/linux_optdeps/scripts/install.sh rename to .github/unittest/linux_optdeps/scripts/install.sh index 6a4cb8b0732..e7d48b4cb9b 100755 --- a/.circleci/unittest/linux_optdeps/scripts/install.sh +++ b/.github/unittest/linux_optdeps/scripts/install.sh @@ -20,10 +20,10 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict.git +pip install git+https://github.com/pytorch/tensordict.git # smoke test python -c "import functorch" diff --git a/.github/unittest/linux_optdeps/scripts/post_process.sh b/.github/unittest/linux_optdeps/scripts/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_optdeps/scripts/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_optdeps/scripts/run-clang-format.py b/.github/unittest/linux_optdeps/scripts/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_optdeps/scripts/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. 
+ # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + 
subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_optdeps/scripts/run_all.sh b/.github/unittest/linux_optdeps/scripts/run_all.sh similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/run_all.sh rename to .github/unittest/linux_optdeps/scripts/run_all.sh diff --git a/.circleci/unittest/linux_optdeps/scripts/run_test.sh b/.github/unittest/linux_optdeps/scripts/run_test.sh similarity index 72% rename from .circleci/unittest/linux_optdeps/scripts/run_test.sh rename to .github/unittest/linux_optdeps/scripts/run_test.sh index 5424fd92374..e8f96dd2425 100755 --- a/.circleci/unittest/linux_optdeps/scripts/run_test.sh +++ b/.github/unittest/linux_optdeps/scripts/run_test.sh @@ -16,6 +16,6 @@ root_dir="$(git rev-parse --show-toplevel)" export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch -MUJOCO_GL=egl python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py +MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_optdeps/scripts/setup_env.sh b/.github/unittest/linux_optdeps/scripts/setup_env.sh similarity index 100% rename from .circleci/unittest/linux_optdeps/scripts/setup_env.sh rename to .github/unittest/linux_optdeps/scripts/setup_env.sh diff --git a/.circleci/unittest/windows_optdepts/scripts/environment.yml b/.github/unittest/windows_optdepts/scripts/environment.yml similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/environment.yml rename to 
.github/unittest/windows_optdepts/scripts/environment.yml diff --git a/.circleci/unittest/windows_optdepts/scripts/install.sh b/.github/unittest/windows_optdepts/scripts/install.sh similarity index 97% rename from .circleci/unittest/windows_optdepts/scripts/install.sh rename to .github/unittest/windows_optdepts/scripts/install.sh index 55c536ac729..565535a2f1e 100644 --- a/.circleci/unittest/windows_optdepts/scripts/install.sh +++ b/.github/unittest/windows_optdepts/scripts/install.sh @@ -57,7 +57,7 @@ fi #python -m pip install pip --upgrade # install tensordict -pip3 install git+https://github.com/pytorch-labs/tensordict +pip3 install git+https://github.com/pytorch/tensordict # smoke test python -c """ diff --git a/.circleci/unittest/windows_optdepts/scripts/install_conda.bat b/.github/unittest/windows_optdepts/scripts/install_conda.bat similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/install_conda.bat rename to .github/unittest/windows_optdepts/scripts/install_conda.bat diff --git a/.circleci/unittest/windows_optdepts/scripts/post_process.sh b/.github/unittest/windows_optdepts/scripts/post_process.sh similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/post_process.sh rename to .github/unittest/windows_optdepts/scripts/post_process.sh diff --git a/.circleci/unittest/windows_optdepts/scripts/run_test.sh b/.github/unittest/windows_optdepts/scripts/run_test.sh similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/run_test.sh rename to .github/unittest/windows_optdepts/scripts/run_test.sh diff --git a/.circleci/unittest/windows_optdepts/scripts/set_cuda_envs.sh b/.github/unittest/windows_optdepts/scripts/set_cuda_envs.sh similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/set_cuda_envs.sh rename to .github/unittest/windows_optdepts/scripts/set_cuda_envs.sh diff --git a/.circleci/unittest/windows_optdepts/scripts/setup_env.sh b/.github/unittest/windows_optdepts/scripts/setup_env.sh similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/setup_env.sh rename to .github/unittest/windows_optdepts/scripts/setup_env.sh diff --git a/.circleci/unittest/windows_optdepts/scripts/vc_env_helper.bat b/.github/unittest/windows_optdepts/scripts/vc_env_helper.bat similarity index 100% rename from .circleci/unittest/windows_optdepts/scripts/vc_env_helper.bat rename to .github/unittest/windows_optdepts/scripts/vc_env_helper.bat diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 77d695fc76f..1a2384a1df1 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -31,7 +31,7 @@ jobs: - name: Setup Environment run: | python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - python -m pip install git+https://github.com/pytorch-labs/tensordict + python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark python -m pip install dm_control @@ -94,7 +94,7 @@ jobs: - name: Setup Environment run: | python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 - python3 -m pip install git+https://github.com/pytorch-labs/tensordict + python3 -m pip install git+https://github.com/pytorch/tensordict python3 setup.py develop python3 -m pip install pytest pytest-benchmark python3 -m pip install dm_control diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 
091581cb557..e44c683a6d6 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -30,7 +30,7 @@ jobs: - name: Setup Environment run: | python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - python -m pip install git+https://github.com/pytorch-labs/tensordict + python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark python -m pip install dm_control @@ -105,7 +105,7 @@ jobs: - name: Setup Environment run: | python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 - python3 -m pip install git+https://github.com/pytorch-labs/tensordict + python3 -m pip install git+https://github.com/pytorch/tensordict python3 setup.py develop python3 -m pip install pytest pytest-benchmark python3 -m pip install dm_control diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml new file mode 100644 index 00000000000..6ef2cc1ecd0 --- /dev/null +++ b/.github/workflows/build-wheels-m1.yml @@ -0,0 +1,43 @@ +name: Build M1 Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: macos-arm64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build: + needs: generate-matrix + name: pytorch/rl + uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main + with: + repository: pytorch/rl + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: .github/scripts/pre_build_script_m1.sh + post-script: "" + package-name: torchrl + runner-type: macos-m1-12 + smoke-test-script: "" + trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/m1_script.sh + secrets: + AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} + AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 16acc4aa5ac..bc0ae7be205 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -60,7 +60,7 @@ jobs: #pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu118 --quiet --root-user-action=ignore - name: Install tensordict run: | - pip3 install git+https://github.com/pytorch-labs/tensordict.git --quiet --root-user-action=ignore + pip3 install git+https://github.com/pytorch/tensordict.git --quiet --root-user-action=ignore - name: Install TorchRL run: | python3 setup.py develop @@ -88,7 +88,7 @@ jobs: apt-get update && apt-get install -y rsync - name: Pull TensorDict docs run: | - git clone --branch gh-pages https://github.com/pytorch-labs/tensordict.git docs/_local_build/tensordict + git clone --branch gh-pages https://github.com/pytorch/tensordict.git docs/_local_build/tensordict rm -rf docs/_local_build/tensordict/.git - name: Get output time run: echo "The time was ${{ steps.build.outputs.time }}" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e43be123024..c9020a04841 100644 --- 
a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -66,7 +66,7 @@ jobs: echo '::group::Lint C source' set +e - ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchrl/csrc --clang-format-executable ./clang-format + ./.github/unittest/linux/scripts/run-clang-format.py -r torchrl/csrc --clang-format-executable ./clang-format if [ $? -ne 0 ]; then git --no-pager diff diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 6dd53d93c56..923a3f3dfc1 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]] + python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]] cuda_support: [["", "cpu", "cpu"]] container: pytorch/manylinux-cuda116 steps: @@ -45,7 +45,7 @@ jobs: - name: Install PyTorch nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda_support[1] }} + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda_support[1] }} - name: Build TorchRL Nightly run: | rm -r dist || true @@ -73,7 +73,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -84,7 +84,7 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch nightly run: | - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - name: Build TorchRL Nightly run: | rm -r dist || true @@ -106,7 +106,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -117,7 +117,7 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch Nightly run: | - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - name: Upgrade pip run: | python3 -mpip install --upgrade pip @@ -126,7 +126,7 @@ jobs: python3 -mpip install numpy pytest --no-cache-dir - name: Install tensordict run: | - python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git + python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Download built wheels uses: actions/download-artifact@v2 with: @@ -158,7 +158,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]] + python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]] cuda_support: [["", "cpu", "cpu"]] container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: @@ -189,7 +189,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - 
name: Checkout torchrl uses: actions/checkout@v2 @@ -217,7 +217,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]] + python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]] cuda_support: [["", "cpu", "cpu"]] steps: - name: Setup Python @@ -232,14 +232,14 @@ jobs: - name: Install PyTorch Nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda_support[1] }} + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda_support[1] }} - name: Upgrade pip run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install --upgrade pip - name: Install tensordict run: | - python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git + python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Install test dependencies run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -279,7 +279,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -290,7 +290,7 @@ jobs: - name: Install PyTorch nightly shell: bash run: | - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - name: Build TorchRL nightly shell: bash run: | @@ -312,7 +312,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -323,7 +323,7 @@ jobs: - name: Install PyTorch Nightly shell: bash run: | - python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - name: Upgrade pip shell: bash run: | @@ -334,7 +334,7 @@ jobs: python3 -mpip install numpy pytest --no-cache-dir - name: Install tensordict run: | - python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git + python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Download built wheels uses: actions/download-artifact@v2 with: @@ -369,7 +369,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Checkout torchrl uses: actions/checkout@v2 diff --git a/.github/workflows/test-linux-brax.yml b/.github/workflows/test-linux-brax.yml index 0554b5e6ab1..0a09306f313 100644 --- a/.github/workflows/test-linux-brax.yml +++ b/.github/workflows/test-linux-brax.yml @@ -35,4 +35,4 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_brax/run_all.sh + bash .github/unittest/linux_libs/scripts_brax/run_all.sh diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index 4b34fe8fb8b..a8a349bf478 100644 --- a/.github/workflows/test-linux-cpu.yml +++ 
b/.github/workflows/test-linux-cpu.yml @@ -22,7 +22,7 @@ jobs: tests: strategy: matrix: - python_version: ["3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -34,9 +34,10 @@ jobs: # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" + export TORCH_VERSION=nightly echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" ## setup_env.sh - bash .circleci/unittest/linux/scripts/run_all.sh + bash .github/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-linux-d4rl.yml b/.github/workflows/test-linux-d4rl.yml index e9e5261d510..a5acce1f5c9 100644 --- a/.github/workflows/test-linux-d4rl.yml +++ b/.github/workflows/test-linux-d4rl.yml @@ -31,7 +31,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 - bash .circleci/unittest/linux_libs/scripts_d4rl/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_d4rl/install.sh - bash .circleci/unittest/linux_libs/scripts_d4rl/run_test.sh - bash .circleci/unittest/linux_libs/scripts_d4rl/post_process.sh + bash .github/unittest/linux_libs/scripts_d4rl/setup_env.sh + bash .github/unittest/linux_libs/scripts_d4rl/install.sh + bash .github/unittest/linux_libs/scripts_d4rl/run_test.sh + bash .github/unittest/linux_libs/scripts_d4rl/post_process.sh diff --git a/.github/workflows/test-linux-envpool.yml b/.github/workflows/test-linux-envpool.yml index de78feb8a44..3b1072c9395 100644 --- a/.github/workflows/test-linux-envpool.yml +++ b/.github/workflows/test-linux-envpool.yml @@ -30,7 +30,7 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_envpool/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_envpool/install.sh - bash .circleci/unittest/linux_libs/scripts_envpool/run_test.sh - bash .circleci/unittest/linux_libs/scripts_envpool/post_process.sh + bash .github/unittest/linux_libs/scripts_envpool/setup_env.sh + bash .github/unittest/linux_libs/scripts_envpool/install.sh + bash .github/unittest/linux_libs/scripts_envpool/run_test.sh + bash .github/unittest/linux_libs/scripts_envpool/post_process.sh diff --git a/.github/workflows/test-linux-examples.yml b/.github/workflows/test-linux-examples.yml index 979f047d100..60c64510cb3 100644 --- a/.github/workflows/test-linux-examples.yml +++ b/.github/workflows/test-linux-examples.yml @@ -46,4 +46,4 @@ jobs: echo "CU_VERSION: $CU_VERSION" ## setup_env.sh - bash .circleci/unittest/linux_examples/scripts/run_all.sh + bash .github/unittest/linux_examples/scripts/run_all.sh diff --git a/.github/workflows/test-linux-gpu.yml b/.github/workflows/test-linux-gpu.yml index d576f813e4c..594e6c3c096 100644 --- a/.github/workflows/test-linux-gpu.yml +++ b/.github/workflows/test-linux-gpu.yml @@ -22,8 +22,8 @@ jobs: tests: strategy: matrix: - python_version: ["3.9"] # "3.8", "3.9", "3.10", "3.11" - cuda_arch_version: ["12.1"] # "11.6", "11.7" + python_version: ["3.8"] + cuda_arch_version: ["12.1"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -39,6 +39,7 @@ jobs: # Commenting these out for now because the GPU test are not working inside docker export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + export TORCH_VERSION=nightly # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" @@ -46,4 +47,4 @@ jobs: echo "CU_VERSION: 
$CU_VERSION" ## setup_env.sh - bash .circleci/unittest/linux/scripts/run_all.sh + bash .github/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-linux-gym.yml b/.github/workflows/test-linux-gym.yml index 7dce5ca0e52..0345955808f 100644 --- a/.github/workflows/test-linux-gym.yml +++ b/.github/workflows/test-linux-gym.yml @@ -33,6 +33,6 @@ jobs: export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/work/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin" export TAR_OPTIONS="--no-same-owner" - ./.circleci/unittest/linux_libs/scripts_gym/setup_env.sh - ./.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh - ./.circleci/unittest/linux_libs/scripts_gym/post_process.sh + ./.github/unittest/linux_libs/scripts_gym/setup_env.sh + ./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh + ./.github/unittest/linux_libs/scripts_gym/post_process.sh diff --git a/.github/workflows/test-linux-habitat.yml b/.github/workflows/test-linux-habitat.yml index 1a2ab19d062..734052241d6 100644 --- a/.github/workflows/test-linux-habitat.yml +++ b/.github/workflows/test-linux-habitat.yml @@ -39,4 +39,4 @@ jobs: # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" - bash .circleci/unittest/linux_libs/scripts_habitat/run_all.sh + bash .github/unittest/linux_libs/scripts_habitat/run_all.sh diff --git a/.github/workflows/test-linux-jumanji.yml b/.github/workflows/test-linux-jumanji.yml index 56782eb52d2..a1ca1eb6a41 100644 --- a/.github/workflows/test-linux-jumanji.yml +++ b/.github/workflows/test-linux-jumanji.yml @@ -34,7 +34,7 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_jumanji/install.sh - bash .circleci/unittest/linux_libs/scripts_jumanji/run_test.sh - bash .circleci/unittest/linux_libs/scripts_jumanji/post_process.sh + bash .github/unittest/linux_libs/scripts_jumanji/setup_env.sh + bash .github/unittest/linux_libs/scripts_jumanji/install.sh + bash .github/unittest/linux_libs/scripts_jumanji/run_test.sh + bash .github/unittest/linux_libs/scripts_jumanji/post_process.sh diff --git a/.github/workflows/test-linux-olddeps.yml b/.github/workflows/test-linux-olddeps.yml new file mode 100644 index 00000000000..9f54d9dda25 --- /dev/null +++ b/.github/workflows/test-linux-olddeps.yml @@ -0,0 +1,33 @@ +name: Olddeps Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +jobs: + unittests: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + # gpu-arch-type: cuda + # gpu-arch-version: "11.7" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu116" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + + bash .github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh + bash .github/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh + bash .github/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh diff --git a/.github/workflows/test-linux-optdeps.yml b/.github/workflows/test-linux-optdeps.yml index 0a2999a91e0..79aeb552472 100644 --- a/.github/workflows/test-linux-optdeps.yml +++ b/.github/workflows/test-linux-optdeps.yml @@ -43,4 +43,4 @@ jobs: echo "CU_VERSION: $CU_VERSION" ## setup_env.sh - bash .circleci/unittest/linux_optdeps/scripts/run_all.sh + 
bash .github/unittest/linux_optdeps/scripts/run_all.sh diff --git a/.github/workflows/test-linux-pettingzoo.yml b/.github/workflows/test-linux-pettingzoo.yml new file mode 100644 index 00000000000..628be74beef --- /dev/null +++ b/.github/workflows/test-linux-pettingzoo.yml @@ -0,0 +1,40 @@ +name: PettingZoo Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="11.7" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/install.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/run_test.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/post_process.sh diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-rlhf.yml index 5557d066d2c..86040ae9679 100644 --- a/.github/workflows/test-linux-rlhf.yml +++ b/.github/workflows/test-linux-rlhf.yml @@ -33,7 +33,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 - bash .circleci/unittest/linux_libs/scripts_rlhf/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_rlhf/install.sh - bash .circleci/unittest/linux_libs/scripts_rlhf/run_test.sh - bash .circleci/unittest/linux_libs/scripts_rlhf/post_process.sh + bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh + bash .github/unittest/linux_libs/scripts_rlhf/install.sh + bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh + bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh diff --git a/.github/workflows/test-linux-robohive.yml b/.github/workflows/test-linux-robohive.yml new file mode 100644 index 00000000000..4793971d4a4 --- /dev/null +++ b/.github/workflows/test-linux-robohive.yml @@ -0,0 +1,30 @@ +name: Robohive Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +jobs: + unittests: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.8" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_robohive/setup_env.sh + bash .github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh + bash .github/unittest/linux_libs/scripts_robohive/post_process.sh diff --git a/.github/workflows/test-linux-sklearn.yml b/.github/workflows/test-linux-sklearn.yml index 53422d92115..9ad10a53297 100644 --- 
a/.github/workflows/test-linux-sklearn.yml +++ b/.github/workflows/test-linux-sklearn.yml @@ -33,7 +33,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 - bash .circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_sklearn/install.sh - bash .circleci/unittest/linux_libs/scripts_sklearn/run_test.sh - bash .circleci/unittest/linux_libs/scripts_sklearn/post_process.sh + bash .github/unittest/linux_libs/scripts_sklearn/setup_env.sh + bash .github/unittest/linux_libs/scripts_sklearn/install.sh + bash .github/unittest/linux_libs/scripts_sklearn/run_test.sh + bash .github/unittest/linux_libs/scripts_sklearn/post_process.sh diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml new file mode 100644 index 00000000000..159c93fb1a1 --- /dev/null +++ b/.github/workflows/test-linux-smacv2.yml @@ -0,0 +1,41 @@ +name: SMACv2 Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="11.7" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_smacv2/setup_env.sh + bash .github/unittest/linux_libs/scripts_smacv2/install.sh + bash .github/unittest/linux_libs/scripts_smacv2/run_test.sh + bash .github/unittest/linux_libs/scripts_smacv2/post_process.sh diff --git a/.github/workflows/test-linux-stable-gpu.yml b/.github/workflows/test-linux-stable-gpu.yml new file mode 100644 index 00000000000..0325940df20 --- /dev/null +++ b/.github/workflows/test-linux-stable-gpu.yml @@ -0,0 +1,50 @@ +name: Unit-tests on Linux GPU, latest stable release + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +env: + CHANNEL: "nightly" + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
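+ # For example, on a push to main the group below resolves to "<workflow name>-ci-master-<commit SHA>" (unique per commit, so runs on main are never cancelled), + # while a pull-request run resolves to "<workflow name>-ci-refs/pull/<number>/merge", so a newer push to the same PR cancels the in-flight run.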
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + tests: + strategy: + matrix: + python_version: ["3.8"] # "3.8", "3.9", "3.10", "3.11" + cuda_arch_version: ["11.8"] # "11.6", "11.7" + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.g5.4xlarge.nvidia.gpu + repository: pytorch/rl + docker-image: "nvidia/cuda:12.1.0-devel-ubuntu22.04" + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 90 + script: | + # Set env vars from matrix + export PYTHON_VERSION=${{ matrix.python_version }} + # Commenting these out for now because the GPU test are not working inside docker + export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} + export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + export TORCH_VERSION=stable + # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines + #export CU_VERSION="cpu" + + echo "PYTHON_VERSION: $PYTHON_VERSION" + echo "CU_VERSION: $CU_VERSION" + + ## setup_env.sh + bash .github/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-linux-vmas.yml b/.github/workflows/test-linux-vmas.yml index c345c49572c..fc189b28f7f 100644 --- a/.github/workflows/test-linux-vmas.yml +++ b/.github/workflows/test-linux-vmas.yml @@ -34,7 +34,7 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_vmas/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_vmas/install.sh - bash .circleci/unittest/linux_libs/scripts_vmas/run_test.sh - bash .circleci/unittest/linux_libs/scripts_vmas/post_process.sh + bash .github/unittest/linux_libs/scripts_vmas/setup_env.sh + bash .github/unittest/linux_libs/scripts_vmas/install.sh + bash .github/unittest/linux_libs/scripts_vmas/run_test.sh + bash .github/unittest/linux_libs/scripts_vmas/post_process.sh diff --git a/.github/workflows/test-macos-cpu.yml b/.github/workflows/test-macos-cpu.yml index 6269a5554f3..184cb7e9884 100644 --- a/.github/workflows/test-macos-cpu.yml +++ b/.github/workflows/test-macos-cpu.yml @@ -22,7 +22,7 @@ jobs: tests: strategy: matrix: - python_version: ["3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: false uses: pytorch/test-infra/.github/workflows/macos_job.yml@main with: @@ -32,9 +32,11 @@ jobs: # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" + export SYSTEM_VERSION_COMPAT=0 + export TORCH_VERSION=nightly echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" ## setup_env.sh - ./.circleci/unittest/linux/scripts/run_all.sh + ./.github/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-windows-optdepts-cpu.yml b/.github/workflows/test-windows-optdepts-cpu.yml index 70197687758..1cd161a84fb 100644 --- a/.github/workflows/test-windows-optdepts-cpu.yml +++ b/.github/workflows/test-windows-optdepts-cpu.yml @@ -35,13 +35,13 @@ jobs: echo "CU_VERSION: $CU_VERSION" ## setup_env.sh - ./.circleci/unittest/windows_optdepts/scripts/setup_env.sh + ./.github/unittest/windows_optdepts/scripts/setup_env.sh ## install.sh - ./.circleci/unittest/windows_optdepts/scripts/install.sh + ./.github/unittest/windows_optdepts/scripts/install.sh ## run_test.sh - ./.circleci/unittest/windows_optdepts/scripts/run_test.sh + ./.github/unittest/windows_optdepts/scripts/run_test.sh ## post_process.sh - 
./.circleci/unittest/windows_optdepts/scripts/post_process.sh + ./.github/unittest/windows_optdepts/scripts/post_process.sh diff --git a/.github/workflows/test-windows-optdepts-gpu.yml b/.github/workflows/test-windows-optdepts-gpu.yml index cdd455eb6e6..652e816fd96 100644 --- a/.github/workflows/test-windows-optdepts-gpu.yml +++ b/.github/workflows/test-windows-optdepts-gpu.yml @@ -17,7 +17,7 @@ jobs: unittests: uses: pytorch/test-infra/.github/workflows/windows_job.yml@main with: - runner: "windows.8xlarge.nvidia.gpu" + runner: "windows.g5.4xlarge.nvidia.gpu" repository: pytorch/rl timeout: 240 script: | @@ -33,7 +33,7 @@ jobs: echo "PYTHON_VERSION: $PYTHON_VERSION" ## setup_env.sh - ./.circleci/unittest/windows_optdepts/scripts/setup_env.sh + ./.github/unittest/windows_optdepts/scripts/setup_env.sh ## Install CUDA packaging/windows/internal/cuda_install.bat @@ -42,10 +42,10 @@ jobs: packaging/windows/internal/driver_update.bat ## install.sh - ./.circleci/unittest/windows_optdepts/scripts/install.sh + ./.github/unittest/windows_optdepts/scripts/install.sh ## run_test.sh - ./.circleci/unittest/windows_optdepts/scripts/run_test.sh + ./.github/unittest/windows_optdepts/scripts/run_test.sh ## post_process.sh - ./.circleci/unittest/windows_optdepts/scripts/post_process.sh + ./.github/unittest/windows_optdepts/scripts/post_process.sh diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 2be76adc994..302c0350c6f 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -4,7 +4,7 @@ on: types: [opened, synchronize, reopened] push: branches: - - release/0.1.1 + - release/0.2.0 concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. @@ -18,8 +18,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]] - cuda_support: [["", "--extra-index-url https://download.pytorch.org/whl/cpu", "\"['cpu', '11.3', '11.6']\"", "cpu"]] + python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]] + cuda_support: [["", "--index-url https://download.pytorch.org/whl/cpu", "\"['cpu', '11.3', '11.6']\"", "cpu"]] container: pytorch/manylinux-${{ matrix.cuda_support[3] }} steps: - name: Checkout torchrl @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.1.1 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . 
-name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -56,7 +56,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -67,12 +67,12 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch RC run: | - python3 -mpip install torch --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install torch --index-url https://download.pytorch.org/whl/cpu - name: Build wheel run: | export CC=clang CXX=clang++ python3 -mpip install wheel - BUILD_VERSION=0.1.1 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: @@ -88,7 +88,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -99,12 +99,12 @@ jobs: - name: Install PyTorch RC shell: bash run: | - python3 -mpip install torch --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install torch --index-url https://download.pytorch.org/whl/cpu - name: Build wheel shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.1.1 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: @@ -122,7 +122,7 @@ jobs: strategy: matrix: os: [["linux", "ubuntu-20.04"], ["mac", "macos-latest"]] - python_version: [ "3.8", "3.9", "3.10" ] + python_version: [ "3.8", "3.9", "3.10", "3.11" ] runs-on: ${{ matrix.os[1] }} steps: - name: Setup Python @@ -134,13 +134,13 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch RC run: | - python3 -mpip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - name: Upgrade pip run: | python3 -mpip install --upgrade pip - name: Install tensordict run: | - python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git + python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Install test dependencies run: | python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml @@ -172,7 +172,7 @@ jobs: needs: build-wheel-windows strategy: matrix: - python_version: [ "3.8", "3.9", "3.10" ] + python_version: [ "3.8", "3.9", "3.10", "3.11" ] runs-on: windows-latest steps: - name: Setup Python @@ -184,7 +184,7 @@ jobs: - name: Install PyTorch RC shell: bash run: | - python3 -mpip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - name: Upgrade pip shell: bash run: | @@ -192,7 +192,7 @@ jobs: - name: Install tensordict shell: bash run: | - python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git + python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Install test dependencies shell: bash run: | diff --git a/.gitignore b/.gitignore index 6ce8f3e06d0..563891baa98 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 
+70,11 @@ instance/ # Sphinx documentation docs/_build/ +docs/build/ +docs/source/gen_modules +docs/source/reference/generated +docs/source/tutorials +docs/src # PyBuilder .pybuilder/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9d2f0de2a33..971ea8516dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: additional_dependencies: - flake8-bugbear==22.10.27 - flake8-comprehensions==3.10.1 - + - torchfix==0.0.2 - repo: https://github.com/PyCQA/pydocstyle rev: 6.1.1 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 843a9712369..9f532397241 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,7 +15,7 @@ pip install tensordict-nightly ``` or the git version of the library: ``` -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch/tensordict ``` Once cloned, make sure you install torchrl in develop mode by running diff --git a/README.md b/README.md index 70d474ff78d..9220fdbcd10 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ -[![pytorch](https://circleci.com/gh/pytorch/rl.svg?style=shield)](https://circleci.com/gh/pytorch/rl) +[![Unit-tests](https://github.com/pytorch/rl/actions/workflows/test-linux-gpu.yml/badge.svg)](https://github.com/pytorch/rl/actions/workflows/test-linux-gpu.yml) [![Documentation](https://img.shields.io/badge/Documentation-blue.svg)](https://pytorch.org/rl/) [![Benchmarks](https://img.shields.io/badge/Benchmarks-blue.svg)](https://pytorch.github.io/rl/dev/bench/) [![codecov](https://codecov.io/gh/pytorch/rl/branch/main/graph/badge.svg?token=HcpK1ILV6r)](https://codecov.io/gh/pytorch/rl) [![Twitter Follow](https://img.shields.io/twitter/follow/torchrl1?style=social)](https://twitter.com/torchrl1) -[![Python 3.7, 3.8](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org/downloads/) +[![Python version](https://img.shields.io/pypi/pyversions/torchrl.svg)](https://www.python.org/downloads/) [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/pytorch/rl/blob/main/LICENSE) pypi version pypi nightly version @@ -50,7 +50,7 @@ We have some introductory videos for you to get to know the library better, chec RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, from state-based to pixel-based learning). -TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch-labs/tensordict/), +TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch/tensordict/), a convenient data structure(1) that can be used to streamline one's RL codebase. With this tool, one can write a *complete PPO training script in less than 100 @@ -219,7 +219,7 @@ to be easily recycled across settings. ``` -TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch-labs.github.io/tensordict/reference/nn.html) +TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch.github.io/tensordict/reference/nn.html) module that contains everything you might need to write your model with it. And it is `functorch` and `torch.compile` compatible! @@ -256,7 +256,7 @@ And it is `functorch` and `torch.compile` compatible! ``` - Check [TensorDict tutorials](https://pytorch-labs.github.io/tensordict/) to + Check [TensorDict tutorials](https://pytorch.github.io/tensordict/) to learn more! @@ -265,7 +265,7 @@ And it is `functorch` and `torch.compile` compatible! 
- A common [interface for environments](torchrl/envs) which supports common libraries (OpenAI gym, deepmind control lab, etc.)(1) and state-less execution (e.g. Model-based environments). - The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution(2). + The [batched environments](torchrl/envs/batched_envs.py) containers allow parallel execution(2). A common PyTorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided. TorchRL's environments API is simple but stringent and specific. Check the [documentation](https://pytorch.org/rl/reference/envs.html) @@ -384,7 +384,7 @@ And it is `functorch` and `torch.compile` compatible! ``` -- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch-labs/tensordict/blob/main/tensordict/memmap.py))(2); +- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py))(2); - various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))(1):
Code @@ -470,7 +470,7 @@ And it is `functorch` and `torch.compile` compatible! ### Advantage computation ```python from torchrl.objectives.value.functional import vec_td_lambda_return_estimate - advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done) + advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated) ```
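The hunk above adds a `terminated` argument to `vec_td_lambda_return_estimate`, matching the done/terminated split documented later in this patch (`done` marks the end of a trajectory for any reason, `terminated` marks a natural episode end with no bootstrapping past it). A minimal sketch of the updated call, assuming scalar `gamma`/`lmbda` and `[batch, time, 1]`-shaped tensors; shapes and values here are illustrative only and mirror the snippet above:

```python
import torch
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate

# Illustrative shapes: 4 trajectories, 10 time steps, 1 reward dimension.
B, T = 4, 10
next_state_value = torch.randn(B, T, 1)
reward = torch.randn(B, T, 1)
# "terminated" flags natural episode ends; "done" is the union of
# termination and truncation (here they coincide on the last step).
terminated = torch.zeros(B, T, 1, dtype=torch.bool)
terminated[:, -1] = True
done = terminated.clone()

advantage = vec_td_lambda_return_estimate(
    0.99,   # gamma
    0.95,   # lmbda
    next_state_value,
    reward,
    done,
    terminated,
)
print(advantage.shape)  # torch.Size([4, 10, 1])
```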
@@ -493,12 +493,15 @@ A series of [examples](examples/) are provided with an illustrative purpose: - [DQN and Rainbow](examples/dqn/dqn.py) - [DDPG](examples/ddpg/ddpg.py) - [IQL](examples/iql/iql.py) +- [CQL](examples/iql/cql.py) - [TD3](examples/td3/td3.py) - [A2C](examples/a2c_old/a2c.py) - [PPO](examples/ppo/ppo.py) - [SAC](examples/sac/sac.py) - [REDQ](examples/redq/redq.py) - [Dreamer](examples/dreamer/dreamer.py) +- [Decision Transformers](examples/decision_transformer) +- [RLHF](examples/rlhf) and many more to come! diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py index d786cc4244d..7f320ff2e8d 100644 --- a/benchmarks/conftest.py +++ b/benchmarks/conftest.py @@ -57,7 +57,7 @@ def pytest_addoption(parser): parser.addoption("--rank", action="store") -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -69,3 +69,23 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Deprecated call to `pkg_resources.declare_namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Using or importing the ABCs", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Please use `coo_matrix` from the `scipy.sparse` namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated", + ) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py new file mode 100644 index 00000000000..146d011442d --- /dev/null +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""This script executes some envs across the Gym library with the explicit scope of testing the throughput using the various TorchRL components. + +We test: +- gym async envs embedded in a TorchRL's GymEnv wrapper, +- ParallelEnv with regular GymEnv instances, +- Data collector +- Multiprocessed data collectors with parallel envs. + +The tests are executed with various number of cpus, and on different devices. 
+ +""" +import time + +import myosuite # noqa: F401 +import tqdm +from torchrl._utils import timeit +from torchrl.collectors import ( + MultiaSyncDataCollector, + MultiSyncDataCollector, + RandomPolicy, + SyncDataCollector, +) +from torchrl.envs import EnvCreator, GymEnv, ParallelEnv +from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend + +if __name__ == "__main__": + for envname in [ + "CartPole-v1", + "HalfCheetah-v4", + "myoHandReachRandom-v0", + "ALE/Breakout-v5", + ]: + # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes + for num_workers, num_collectors in zip((32, 64, 8, 16), (8, 8, 2, 4)): + with open(f"{envname}_{num_workers}.txt".replace("/", "-"), "w+") as log: + if "myo" in envname: + gym_backend = "gym" + else: + gym_backend = "gymnasium" + + total_frames = num_workers * 10_000 + + # pure gym + def make(envname=envname, gym_backend=gym_backend): + with set_gym_backend(gym_backend): + return gym_bc().make(envname) + + with set_gym_backend(gym_backend): + env = gym_bc().vector.AsyncVectorEnv( + [make for _ in range(num_workers)] + ) + env.reset() + global_step = 0 + times = [] + start = time.time() + print("Timer started.") + for _ in tqdm.tqdm(range(total_frames // num_workers)): + env.step(env.action_space.sample()) + global_step += num_workers + env.close() + log.write( + f"pure gym: {num_workers * 10_000 / (time.time() - start): 4.4f} fps\n" + ) + log.flush() + + # regular parallel env + for device in ( + "cuda:0", + "cpu", + ): + + def make(envname=envname, gym_backend=gym_backend, device=device): + with set_gym_backend(gym_backend): + return GymEnv(envname, device=device) + + env_make = EnvCreator(make) + penv = ParallelEnv(num_workers, env_make) + # warmup + penv.rollout(2) + pbar = tqdm.tqdm(total=num_workers * 10_000) + t0 = time.time() + for _ in range(100): + data = penv.rollout(100, break_when_any_done=False) + pbar.update(100 * num_workers) + log.write( + f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + penv.close() + timeit.print() + del penv + + for device in ("cuda:0", "cpu"): + + def make(envname=envname, gym_backend=gym_backend, device=device): + with set_gym_backend(gym_backend): + return GymEnv(envname, device=device) + + env_make = EnvCreator(make) + # penv = SerialEnv(num_workers, env_make) + penv = ParallelEnv(num_workers, env_make) + collector = SyncDataCollector( + penv, + RandomPolicy(penv.action_spec), + frames_per_batch=1024, + total_frames=num_workers * 10_000, + ) + pbar = tqdm.tqdm(total=num_workers * 10_000) + total_frames = 0 + for i, data in enumerate(collector): + if i == num_collectors: + t0 = time.time() + if i >= num_collectors: + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + log.write( + f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + collector.shutdown() + del collector + + for device in ( + "cuda:0", + "cpu", + ): + # gym parallel env + def make_env( + envname=envname, + num_workers=num_workers, + gym_backend=gym_backend, + device=device, + ): + with set_gym_backend(gym_backend): + penv = GymEnv(envname, num_envs=num_workers, device=device) + return penv + + penv = make_env() + # warmup + penv.rollout(2) + pbar = tqdm.tqdm(total=num_workers * 10_000) + t0 = time.time() + for _ in range(100): + data = penv.rollout(100, 
break_when_any_done=False) + pbar.update(100 * num_workers) + log.write( + f"gym penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + penv.close() + del penv + + for device in ( + "cuda:0", + "cpu", + ): + # async collector + # + torchrl parallel env + def make_env( + envname=envname, gym_backend=gym_backend, device=device + ): + with set_gym_backend(gym_backend): + return GymEnv(envname, device=device) + + penv = ParallelEnv( + num_workers // num_collectors, EnvCreator(make_env) + ) + collector = MultiaSyncDataCollector( + [penv] * num_collectors, + policy=RandomPolicy(penv.action_spec), + frames_per_batch=1024, + total_frames=num_workers * 10_000, + device=device, + ) + pbar = tqdm.tqdm(total=num_workers * 10_000) + total_frames = 0 + for i, data in enumerate(collector): + if i == num_collectors: + t0 = time.time() + if i >= num_collectors: + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + log.write( + f"async collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + collector.shutdown() + del collector + + for device in ( + "cuda:0", + "cpu", + ): + # async collector + # + gym async env + def make_env( + envname=envname, + num_workers=num_workers, + gym_backend=gym_backend, + device=device, + ): + with set_gym_backend(gym_backend): + penv = GymEnv(envname, num_envs=num_workers, device=device) + return penv + + penv = EnvCreator( + lambda num_workers=num_workers // num_collectors: make_env( + num_workers=num_workers + ) + ) + collector = MultiaSyncDataCollector( + [penv] * num_collectors, + policy=RandomPolicy(penv().action_spec), + frames_per_batch=1024, + total_frames=num_workers * 10_000, + num_sub_threads=num_workers // num_collectors, + device=device, + ) + pbar = tqdm.tqdm(total=num_workers * 10_000) + total_frames = 0 + for i, data in enumerate(collector): + if i == num_collectors: + t0 = time.time() + if i >= num_collectors: + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + log.write( + f"async collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + collector.shutdown() + del collector + + for device in ( + "cuda:0", + "cpu", + ): + # sync collector + # + torchrl parallel env + def make_env( + envname=envname, gym_backend=gym_backend, device=device + ): + with set_gym_backend(gym_backend): + return GymEnv(envname, device=device) + + penv = ParallelEnv( + num_workers // num_collectors, EnvCreator(make_env) + ) + collector = MultiSyncDataCollector( + [penv] * num_collectors, + policy=RandomPolicy(penv.action_spec), + frames_per_batch=1024, + total_frames=num_workers * 10_000, + device=device, + ) + pbar = tqdm.tqdm(total=num_workers * 10_000) + total_frames = 0 + for i, data in enumerate(collector): + if i == num_collectors: + t0 = time.time() + if i >= num_collectors: + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + log.write( + f"sync collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + collector.shutdown() + del collector + + for device in ( + "cuda:0", + "cpu", + ): + # sync collector + # + gym async env + def make_env( + envname=envname, + 
num_workers=num_workers, + gym_backend=gym_backend, + device=device, + ): + with set_gym_backend(gym_backend): + penv = GymEnv(envname, num_envs=num_workers, device=device) + return penv + + penv = EnvCreator( + lambda num_workers=num_workers // num_collectors: make_env( + num_workers=num_workers + ) + ) + collector = MultiSyncDataCollector( + [penv] * num_collectors, + policy=RandomPolicy(penv().action_spec), + frames_per_batch=1024, + total_frames=num_workers * 10_000, + num_sub_threads=num_workers // num_collectors, + device=device, + ) + pbar = tqdm.tqdm(total=num_workers * 10_000) + total_frames = 0 + for i, data in enumerate(collector): + if i == num_collectors: + t0 = time.time() + if i >= num_collectors: + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + log.write( + f"sync collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" + ) + log.flush() + collector.shutdown() + del collector + exit() diff --git a/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py new file mode 100644 index 00000000000..daaf800353f --- /dev/null +++ b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import pickle + +import time +from pathlib import Path +from typing import Dict + +import numpy as np + +import ray + +import vmas +from matplotlib import pyplot as plt +from ray import tune + +from ray.rllib.agents.ppo import PPOTrainer +from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.tune import register_env +from torchrl.collectors import SyncDataCollector +from torchrl.envs.libs.vmas import VmasEnv +from vmas import Wrapper + + +def store_pickled_evaluation(name: str, evaluation: dict): + save_folder = f"{os.path.dirname(os.path.realpath(__file__))}" + file = f"{save_folder}/{name}.pkl" + + pickle.dump(evaluation, open(file, "wb")) + + +def load_pickled_evaluation( + name: str, +): + save_folder = f"{os.path.dirname(os.path.realpath(__file__))}" + file = Path(f"{save_folder}/{name}.pkl") + + if file.is_file(): + return pickle.load(open(file, "rb")) + return None + + +def run_vmas_torchrl( + scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0 +): + env = VmasEnv( + scenario_name, + device=device, + num_envs=n_envs, + continuous_actions=False, + seed=seed, + ) + + collector = SyncDataCollector( + env, + policy=None, + device=device, + frames_per_batch=n_envs * n_steps, + total_frames=n_envs * n_steps, + ) + + init_time = time.time() + + for _data in collector: + pass + + total_time = time.time() - init_time + collector.shutdown() + return total_time + + +def run_vmas_rllib( + scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0 +): + class TimerCallback(DefaultCallbacks): + result_time = None + + def on_train_result( + self, + *, + algorithm, + result: dict, + **kwargs, + ) -> None: + TimerCallback.result_time = ( + result["timers"]["training_iteration_time_ms"] + - result["timers"]["learn_time_ms"] + ) + + def env_creator(config: Dict): + env = vmas.make_env( + scenario=config["scenario_name"], + num_envs=config["num_envs"], + device=config["device"], + continuous_actions=False, + 
wrapper=Wrapper.RLLIB, + ) + return env + + if not ray.is_initialized(): + ray.init() + register_env(scenario_name, lambda config: env_creator(config)) + + num_gpus = 0.5 if device == "cuda" else 0 + num_gpus_per_worker = 0.5 if device == "cuda" else 0 + tune.run( + PPOTrainer, + stop={"training_iteration": 1}, + config={ + "seed": seed, + "framework": "torch", + "env": scenario_name, + "train_batch_size": n_envs * n_steps, + "rollout_fragment_length": n_steps, + "sgd_minibatch_size": n_envs * n_steps, + "num_gpus": num_gpus, + "num_workers": 0, + "num_gpus_per_worker": num_gpus_per_worker, + "num_envs_per_worker": n_envs, + "batch_mode": "truncate_episodes", + "env_config": { + "device": device, + "num_envs": n_envs, + "scenario_name": scenario_name, + "max_steps": n_steps, + }, + "callbacks": TimerCallback, + }, + ) + assert TimerCallback.result_time is not None + TimerCallback.result_time /= 1_000 # convert to seconds + return TimerCallback.result_time + + +def run_comparison_torchrl_rllib( + scenario_name: str, + device: str, + n_steps: int = 100, + max_n_envs: int = 3000, + step_n_envs: int = 3, +): + """ + + Args: + scenario_name (str): name of scenario to benchmark + device (str): device to ron comparison on ("cpu" or "cuda") + n_steps (int): number of environment steps + max_n_envs (int): the maximum number of parallel environments to test + step_n_envs (int): the step size in number of environments from 1 to max_n_envs + + """ + list_n_envs = np.linspace(1, max_n_envs, step_n_envs) + + figure_name = f"VMAS_{scenario_name}_{n_steps}_{device}_steps_rllib_vs_torchrl" + figure_name_pkl = figure_name + f"_range_{1}_{max_n_envs}_num_{step_n_envs}" + + evaluation = load_pickled_evaluation(figure_name_pkl) + if not evaluation: + evaluation = {} + for framework in ["TorchRL", "RLlib"]: + if framework not in evaluation.keys(): + print(f"\nFramework {framework}") + vmas_times = [] + for n_envs in list_n_envs: + n_envs = int(n_envs) + print(f"Running {n_envs} environments") + if framework == "TorchRL": + vmas_times.append( + (n_envs * n_steps) + / run_vmas_torchrl( + scenario_name=scenario_name, + n_envs=n_envs, + n_steps=n_steps, + device=device, + ) + ) + else: + vmas_times.append( + (n_envs * n_steps) + / run_vmas_rllib( + scenario_name=scenario_name, + n_envs=n_envs, + n_steps=n_steps, + device=device, + ) + ) + print(f"fps {vmas_times[-1]}s") + evaluation[framework] = vmas_times + + store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation) + + fig, ax = plt.subplots() + for key, item in evaluation.items(): + ax.plot( + list_n_envs, + item, + label=key, + ) + + plt.xlabel("Number of batched environments", fontsize=14) + plt.ylabel("Frames per second", fontsize=14) + ax.legend(loc="upper left") + + ax.set_title( + f"Execution time of '{scenario_name}' for {n_steps} steps on {device}.", + fontsize=8, + ) + + save_folder = os.path.dirname(os.path.realpath(__file__)) + plt.savefig(f"{save_folder}/{figure_name}.pdf") + + +if __name__ == "__main__": + # pip install matplotlib + # pip install "ray[rllib]"==2.1.0 + # pip install torchrl + # pip install vmas + # pip install numpy==1.23.5 + + run_comparison_torchrl_rllib( + scenario_name="simple_spread", + device="cuda", + n_steps=100, + max_n_envs=30000, + step_n_envs=10, + ) diff --git a/benchmarks/test_envs_benchmark.py b/benchmarks/test_envs_benchmark.py index d873b5efb6a..da49cdae5b1 100644 --- a/benchmarks/test_envs_benchmark.py +++ b/benchmarks/test_envs_benchmark.py @@ -118,9 +118,9 @@ def test_step_mdp_speed( benchmark( 
step_mdp, td, - action_key=action_key, - reward_key=reward_key, - done_key=done_key, + action_keys=action_key, + reward_keys=reward_key, + done_keys=done_key, keep_other=keep_other, exclude_reward=exclude_reward, exclude_done=exclude_done, diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 4fd365afe37..ca5b7eb82ed 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -435,7 +435,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= loss = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), minimum=-1, maximum=1), + action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), ) loss(td) diff --git a/docs/requirements.txt b/docs/requirements.txt index 177fa84e9dc..8bb409ff326 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -24,3 +24,4 @@ imageio[ffmpeg,pyav] memory_profiler pyrender pytest +vmas==1.2.11 diff --git a/docs/source/conf.py b/docs/source/conf.py index 497a0df4fdb..00acf6b67ed 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ intersphinx_mapping = { "torch": ("https://pytorch.org/docs/stable/", None), - "tensordict": ("https://pytorch-labs.github.io/tensordict/", None), + "tensordict": ("https://pytorch.github.io/tensordict/", None), # "torchrl": ("https://pytorch.org/rl/", None), "torchaudio": ("https://pytorch.org/audio/stable/", None), "torchtext": ("https://pytorch.org/text/stable/", None), diff --git a/docs/source/index.rst b/docs/source/index.rst index 75af9a95f64..91906abb857 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,6 +49,7 @@ Intermediate .. toctree:: :maxdepth: 1 + tutorials/multiagent_ppo tutorials/torchrl_envs tutorials/pretrained_models tutorials/dqn_with_rnn diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index d34a266d3db..aa8de179f20 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -105,6 +105,11 @@ node or across multiple nodes. building a parallel environment or collector can result in a slower collection than using ``device="cuda"`` when available. +.. note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. .. currentmodule:: torchrl.collectors.distributed diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 83b1fea1464..cd2b71a0922 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -215,6 +215,8 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. 
OneHotDiscreteTensorSpec UnboundedContinuousTensorSpec UnboundedDiscreteTensorSpec + LazyStackedTensorSpec + LazyStackedCompositeSpec Reinforcement Learning From Human Feedback (RLHF) ------------------------------------------------- @@ -253,3 +255,6 @@ Utils :template: rl_template.rst MultiStep + consolidate_spec + check_no_exclusive_keys + contains_lazy_spec diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 50519dc85fa..f6a5d24e2f8 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -36,12 +36,12 @@ Each env will have the following attributes: - :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing the reward spec. - :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing - the done-flag spec. + the done-flag spec. See the section on trajectory termination below. - :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing - all the input keys (:obj:`"_action_spec"` and :obj:`"_state_spec"`). + all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). It is locked and should not be modified directly. - :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing - all the output keys (:obj:`"_observation_spec"`, :obj:`"_reward_spec"` and :obj:`"_done_spec"`). + all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). It is locked and should not be modified directly. Importantly, the environment spec shapes should contain the batch size, e.g. @@ -79,22 +79,25 @@ The following figure summarizes how a rollout is executed in torchrl. In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method, then populated with an action by the policy before being passed to the -:meth:`~.EnvBase.step` method which writes the observations, done flag and +:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and reward under the ``"next"`` entry. The result of this call is stored for delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp` function. .. note:: - - The Gym(nasium) API recently shifted to a splitting of the ``"done"`` state - into a ``terminated`` (the env is done and results should not be trusted) - and ``truncated`` (the maximum number of steps is reached) flags. - In TorchRL, ``"done"`` usually refers to ``"terminated"``. Truncation is - achieved via the :class:`~.StepCounter` transform class, and the output - key will be ``"truncated"`` if not chosen to be something else (e.g. - ``StepCounter(max_steps=100, truncated_key="done")``). - TorchRL's collectors and rollout methods will be looking for one of these - keys when assessing if the env should be reset. + In general, all TorchRL environment have a ``"done"`` and ``"terminated"`` + entry in their output tensordict. If they are not present by design, + the :class:`~.EnvBase` metaclass will ensure that every done or terminated + is flanked with its dual. + In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory + signals and should be interpreted as "the last step of a trajectory" or + equivalently "a signal indicating the need to reset". + If the environment provides it (eg, Gymnasium), the truncation entry is also + written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry. + If the environment carries a single value, it will interpreted as a ``"terminated"`` + signal by default. 
+ By default, TorchRL's collectors and rollout methods will be looking for the ``"done"`` + entry to assess if the environment should be reset. .. note:: @@ -136,6 +139,12 @@ environments in parallel. As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environment. Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count: +.. note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. + It is important that your environment specs match the input and output that it sends and receives, as :class:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check. @@ -166,12 +175,13 @@ It is also possible to reset some but not all of the environments: :caption: Parallel environment reset >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4]) - >>> env.reset(tensordict) + >>> env.reset(tensordict) # eliminates the "_reset" entry TensorDict( fields={ + terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool), done: Tensor(torch.Size([4, 1]), dtype=torch.bool), pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), - _reset: Tensor(torch.Size([4, 1]), dtype=torch.bool)}, + truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool), batch_size=torch.Size([4]), device=None, is_shared=True) @@ -213,12 +223,13 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w SerialEnv ParallelEnv - MultiThreadedEnv EnvCreator Multi-agent environments ------------------------ +.. currentmodule:: torchrl.envs + TorchRL supports multi-agent learning out-of-the-box. *The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, without any modification or dedicated multi-agent infrastructure.* @@ -231,7 +242,7 @@ Some of the main differences between these paradigms include: - **observation** can be per-agent and also have some shared components - **reward** can be per-agent or shared -- **done** can be per-agent or shared +- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared. TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier. In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict. @@ -340,8 +351,17 @@ single agent standards. spec if the accessed spec is Composite. Therefore, if in the example above we run `env.reward_spec` after env creation, we would get the same output as `torch.stack(reward_specs)}`. To get the full composite spec with the "agents" key, you can run - `env.output_spec["_reward_spec"]`. The same is valid for action and done specs. - Note that `env.reward_spec == env.output_spec["_reward_spec"][env.reward_key]`. + `env.output_spec["full_reward_spec"]`. The same is valid for action and done specs. + Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. + + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + MarlGroupMapType + check_marl_grouping + Transforms @@ -445,13 +465,18 @@ to be able to create this other composition: Transform TransformedEnv + ActionMask BinarizeReward CatFrames CatTensors CenterCrop + ClipTransform Compose + DeviceCastTransform DiscreteActionProjection DoubleToFloat + DTypeCastTransform + EndOfLifeTransform ExcludeTransform FiniteTensorDictCheck FlattenObservation @@ -463,6 +488,7 @@ to be able to create this other composition: NoopResetEnv ObservationNorm ObservationTransform + PermuteTransform PinMemoryTransform R3MTransform RandomCropTensorDict @@ -480,11 +506,60 @@ to be able to create this other composition: TimeMaxPool ToTensorImage UnsqueezeTransform + VecGymEnvTransform VecNorm VC1Transform VIPRewardTransform VIPTransform +Environments with masked actions +-------------------------------- + +In some environments with discrete actions, the actions available to the agent might change throughout execution. +In such cases the environments will output an action mask (under the ``"action_mask"`` key by default). +This mask needs to be used to filter out unavailable actions for that step. + +If you are using a custom policy you can pass this mask to your probability distribution like so: + +.. code-block:: + :caption: Categorical policy with action mask + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.modules import MaskedCategorical + >>> module = TensorDictModule( + >>> nn.Linear(in_feats, out_feats), + >>> in_keys=["observation"], + >>> out_keys=["logits"], + >>> ) + >>> dist = ProbabilisticTensorDictModule( + >>> in_keys={"logits": "logits", "mask": "action_mask"}, + >>> out_keys=["action"], + >>> distribution_class=MaskedCategorical, + >>> ) + >>> actor = TensorDictSequential(module, dist) + +If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask` +transform. This transform can take care of updating the action mask in the action spec in order for the default policy +to always know what the latest available actions are. You can do this like so: + +.. code-block:: + :caption: How to use the action mask transform + + >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential + >>> import torch.nn as nn + >>> from torchrl.envs.transforms import TransformedEnv, ActionMask + >>> env = TransformedEnv( + >>> your_base_env + >>> ActionMask(action_key="action", mask_key="action_mask"), + >>> ) + +.. note:: + In case you are using a parallel environment it is important to add the transform to the parallel enviornment itself + and not to its sub-environments. + + + Recorders --------- @@ -516,6 +591,7 @@ Helpers exploration_type check_env_specs make_composite_from_td + terminated_or_truncated Domain-specific --------------- @@ -532,7 +608,7 @@ Domain-specific Libraries --------- -.. currentmodule:: torchrl.envs.libs +.. currentmodule:: torchrl.envs TorchRL's mission is to make the training of control and decision algorithm as easy as it gets, irrespective of the simulator being used (if any). 
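The listing in the next hunk re-homes the wrapper classes under ``torchrl.envs`` and keeps the Gym backend selector (``gym_backend`` / ``set_gym_backend``) exposed. As a brief sketch of how that selector is used in the throughput benchmark added earlier in this patch (the environment name and backend choice are illustrative, not prescribed by the patch):

```python
from torchrl.envs import GymEnv
from torchrl.envs.libs.gym import set_gym_backend

# Select the backend explicitly so GymEnv wraps gymnasium rather than gym;
# "CartPole-v1" is only an example task.
with set_gym_backend("gymnasium"):
    env = GymEnv("CartPole-v1")

# A short rollout returns a TensorDict carrying the "done"/"terminated" flags.
rollout = env.rollout(3)
print(rollout)
env.close()
```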
@@ -608,19 +684,28 @@ the following function will return ``1`` when queried: :toctree: generated/ :template: rl_template_fun.rst - brax.BraxEnv - brax.BraxWrapper - dm_control.DMControlEnv - dm_control.DMControlWrapper - gym.GymEnv - gym.GymWrapper - gym.MOGymEnv - gym.MOGymWrapper - gym.set_gym_backend - gym.gym_backend - habitat.HabitatEnv - jumanji.JumanjiEnv - jumanji.JumanjiWrapper - openml.OpenMLEnv - vmas.VmasEnv - vmas.VmasWrapper + BraxEnv + BraxWrapper + DMControlEnv + DMControlWrapper + GymEnv + GymWrapper + HabitatEnv + IsaacGymEnv + IsaacGymWrapper + JumanjiEnv + JumanjiWrapper + MOGymEnv + MOGymWrapper + MultiThreadedEnv + MultiThreadedEnvWrapper + OpenMLEnv + PettingZooEnv + PettingZooWrapper + RoboHiveEnv + SMACv2Env + SMACv2Wrapper + VmasEnv + VmasWrapper + gym_backend + set_gym_backend diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index bb650864c0a..978eb610e60 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -68,7 +68,7 @@ other cases, the action written in the tensordict is simply the network output. :template: rl_template_noinherit.rst AdditiveGaussianWrapper - EGreedyWrapper + EGreedyModule OrnsteinUhlenbeckProcessWrapper Probabilistic actors @@ -261,7 +261,7 @@ without shared parameters. It is mainly intended as a replacement for ActorCriticWrapper ActorValueOperator ValueOperator - + DecisionTransformerInferenceWrapper Domain-specific TensorDict modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -308,6 +308,7 @@ Regular modules MLP ConvNet + Conv3dNet LSTMNet SqueezeLayer Squeeze2dLayer @@ -322,18 +323,22 @@ algorithms, such as DQN, DDPG or Dreamer. :toctree: generated/ :template: rl_template_noinherit.rst - DuelingCnnDQNet - DistributionalDQNnet + DTActor DdpgCnnActor DdpgCnnQNet DdpgMlpActor DdpgMlpQNet + DecisionTransformer + DistributionalDQNnet DreamerActor + DuelingCnnDQNet + GRUModule LSTMModule - ObsEncoder ObsDecoder - RSSMPrior + ObsEncoder + OnlineDTActor RSSMPosterior + RSSMPrior Multi-agent-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -346,6 +351,7 @@ multi-agent contexts. :template: rl_template_noinherit.rst MultiAgentMLP + MultiAgentConvNet QMixer VDNMixer @@ -398,6 +404,7 @@ Some distributions are typically used in RL scripts. TanhDelta OneHotCategorical MaskedCategorical + MaskedOneHotCategorical Utils ----- diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 0e63b67db84..26979e2ae96 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -137,6 +137,16 @@ CQL CQLLoss +DT +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DTLoss + OnlineDTLoss + TD3 ---- diff --git a/examples/a2c/README.md b/examples/a2c/README.md new file mode 100644 index 00000000000..513e6d70811 --- /dev/null +++ b/examples/a2c/README.md @@ -0,0 +1,29 @@ +## Reproducing Advantage Actor Critic (A2C) Algorithm Results + +This repository contains scripts that enable training agents using the Advantage Actor Critic (A2C) Algorithm on MuJoCo and Atari environments. We follow the original paper [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783) by Mnih et al. (2016) to implement the A2C algorithm but fix the number of steps during the collection phase. + + +## Examples Structure + +Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: + +1. 
**Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. a2c_atari.py). + +2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py). + +3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml). + + +## Running the Examples + +You can execute the A2C algorithm on Atari environments by running the following command: + +```bash +python a2c_atari.py +``` + +You can execute the A2C algorithm on MuJoCo environments by running the following command: + +```bash +python a2c_mujoco.py +``` diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py deleted file mode 100644 index 74d8f6bc81f..00000000000 --- a/examples/a2c/a2c.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""A2C Example. - -This is a self-contained example of a A2C training script. - -Both state and pixel-based environments are supported. - -The helper functions are coded in the utils.py associated with this script. -""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import torch - import tqdm - from utils import ( - make_a2c_models, - make_collector, - make_logger, - make_loss, - make_optim, - make_test_env, - ) - - # Correct for frame_skip - cfg.collector.total_frames = cfg.collector.total_frames // cfg.env.frame_skip - cfg.collector.frames_per_batch = ( - cfg.collector.frames_per_batch // cfg.env.frame_skip - ) - - model_device = cfg.optim.device - actor, critic = make_a2c_models(cfg) - actor = actor.to(model_device) - critic = critic.to(model_device) - - collector = make_collector(cfg, policy=actor) - loss_module, adv_module = make_loss( - cfg.loss, actor_network=actor, value_network=critic - ) - optim = make_optim(cfg.optim, actor_network=actor, value_network=critic) - - batch_size = cfg.collector.total_frames * cfg.env.num_envs - total_network_updates = cfg.collector.total_frames // batch_size - - scheduler = None - if cfg.optim.lr_scheduler: - scheduler = torch.optim.lr_scheduler.LinearLR( - optim, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1 - ) - - logger = None - if cfg.logger.backend: - logger = make_logger(cfg.logger) - test_env = make_test_env(cfg.env) - record_interval = cfg.logger.log_interval - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - collected_frames = 0 - - # Main loop - r0 = None - l0 = None - for data in collector: - - frames_in_batch = data.numel() - collected_frames += frames_in_batch * cfg.env.frame_skip - pbar.update(data.numel()) - data_view = data.reshape(-1) - - # Compute GAE - with torch.no_grad(): - batch = adv_module(data_view) - - # Normalize advantage - adv = batch.get("advantage") - loc = adv.mean().item() - scale = adv.std().clamp_min(1e-6).item() - adv = (adv - loc) / scale - batch.set("advantage", adv) - - # Forward pass A2C loss - batch = batch.to(model_device) - loss = loss_module(batch) - loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - - # Backward pass + learning step - loss_sum.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - list(actor.parameters()) + 
list(critic.parameters()), max_norm=0.5 - ) - optim.step() - if scheduler is not None: - scheduler.step() - optim.zero_grad() - - # Logging - if r0 is None: - r0 = data["next", "reward"].mean().item() - if l0 is None: - l0 = loss_sum.item() - pbar.set_description( - f"loss: {loss_sum.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['next', 'reward'].mean(): 4.4f} (init={r0: 4.4f})" - ) - if logger is not None: - for key, value in loss.items(): - logger.log_scalar(key, value.item(), collected_frames) - logger.log_scalar("grad_norm", grad_norm.item(), collected_frames) - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - logger.log_scalar( - "reward_training", episode_rewards.mean().item(), collected_frames - ) - collector.update_policy_weights_() - - # Test current policy - if ( - logger is not None - and (collected_frames - frames_in_batch) // record_interval - < collected_frames // record_interval - ): - - with torch.no_grad(): - test_env.eval() - actor.eval() - # Generate a complete episode - td_test = test_env.rollout( - policy=actor, - max_steps=10_000_000, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - ).clone() - logger.log_scalar( - "reward_testing", - td_test["next", "reward"].sum().item(), - collected_frames, - ) - actor.train() - - -if __name__ == "__main__": - main() diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py new file mode 100644 index 00000000000..37c1bd9842d --- /dev/null +++ b/examples/a2c/a2c_atari.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hydra + + +@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_atari import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( + actor.to(device), + critic.to(device), + critic_head.to(device), + ) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and 
adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=True, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + + # use end-of-life as done key + adv_module.set_keys(done="end-of-life", terminated="end-of-life") + loss_module.set_keys(done="end-of-life", terminated="end-of-life") + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="a2c", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + losses = TensorDict({}, batch_size=[num_mini_batches]) + training_start = time.time() + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + batch = batch.to(device) + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + num_network_updates += 1 + + # Forward pass A2C loss + loss = loss_module(batch) + losses[k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + 
test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "test/reward": test_rewards.mean(), + "test/eval_time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py new file mode 100644 index 00000000000..4192ddc6556 --- /dev/null +++ b/examples/a2c/a2c_mujoco.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hydra + + +@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_mujoco import eval_model, make_env, make_ppo_models + + # Define paper hyperparameters + device = "cpu" if not torch.cuda.device_count() else "cuda" + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size + total_network_updates = ( + cfg.collector.total_frames // cfg.collector.frames_per_batch + ) * num_mini_batches + + # Create models (check utils_mujoco.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device), + sampler=sampler, + batch_size=cfg.loss.mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="a2c", experiment_name=exp_name + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device) + test_env.eval() + + # Main loop + collected_frames = 0 + 
num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Get training rewards and lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + losses = TensorDict({}, batch_size=[num_mini_batches]) + training_start = time.time() + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + for group in critic_optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + num_network_updates += 1 + + # Forward pass A2C loss + loss = loss_module(batch) + losses[k] = loss.select( + "loss_critic", "loss_objective" # , "loss_entropy" + ).detach() + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] # + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + # Get training losses + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < ( + i * frames_in_batch + ) // cfg.logger.test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "test/reward": test_rewards.mean(), + "test/eval_time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/a2c/a2c_mujoco_halfcheetah.png b/examples/a2c/a2c_mujoco_halfcheetah.png deleted file mode 100644 index a9cc5deeb2d..00000000000 Binary files a/examples/a2c/a2c_mujoco_halfcheetah.png and /dev/null differ diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml deleted file mode 100644 index 343a59eca6e..00000000000 --- a/examples/a2c/config.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# task and env -env: - env_name: HalfCheetah-v4 - env_task: "" - env_library: gym - frame_skip: 2 - num_envs: 1 - noop: 1 - 
reward_scaling: 1.0 - from_pixels: False - n_samples_stats: 3 - -# collector -collector: - frames_per_batch: 64 - total_frames: 1_000_000 - collector_device: cuda # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: a2c_halfcheetah_gym - log_interval: 10000 - -# Optim -optim: - device: cuda - lr: 0.0005 - weight_decay: 0.0 - lr_scheduler: False - -# loss -loss: - gamma: 0.99 - gae_lamdda: 0.95 - critic_coef: 0.5 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/a2c/config_atari.yaml b/examples/a2c/config_atari.yaml new file mode 100644 index 00000000000..0b06584ee67 --- /dev/null +++ b/examples/a2c/config_atari.yaml @@ -0,0 +1,33 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 1 + +# collector +collector: + frames_per_batch: 80 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0001 + eps: 1.0e-8 + weight_decay: 0.0 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 80 + gae_lambda: 0.95 + critic_coef: 0.25 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/a2c/config_mujoco.yaml b/examples/a2c/config_mujoco.yaml new file mode 100644 index 00000000000..48627059de9 --- /dev/null +++ b/examples/a2c/config_mujoco.yaml @@ -0,0 +1,30 @@ +# task and env +env: + env_name: HalfCheetah-v3 + +# collector +collector: + frames_per_batch: 64 + total_frames: 1_000_000 + +# logger +logger: + backend: wandb + exp_name: Mujoco_Schulman17 + test_interval: 1_000_000 + num_test_episodes: 5 + +# Optim +optim: + lr: 3e-4 + weight_decay: 0.0 + anneal_lr: False + +# loss +loss: + gamma: 0.99 + mini_batch_size: 64 + gae_lambda: 0.95 + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 diff --git a/examples/a2c/training_curves.md b/examples/a2c/training_curves.md deleted file mode 100644 index 52934b2bea6..00000000000 --- a/examples/a2c/training_curves.md +++ /dev/null @@ -1,7 +0,0 @@ -# A2C Example Results - -## MuJoCo HalfCheetah Environment - -We tested the A2C algorithm on the MuJoCo HalfCheetah environment. The hyperparameters used for the training are specified in the config.yaml file. 
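Aside (illustrative sketch, not part of the diff): the new config_atari.yaml and config_mujoco.yaml above are consumed through Hydra, which resolves each YAML section into cfg.<section>.<name>, exactly as the @hydra.main decorator in a2c_mujoco.py / a2c_atari.py does:

import hydra

@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1")
def show(cfg):  # cfg is an OmegaConf DictConfig built from the YAML above
    print(cfg.loss.mini_batch_size)    # 64
    print(cfg.collector.total_frames)  # 1_000_000
    print(cfg.optim.anneal_lr)         # False

if __name__ == "__main__":
    show()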
- -![a2c_mujoco_halfcheetah.png](a2c_mujoco_halfcheetah.png) diff --git a/examples/a2c/utils.py b/examples/a2c/utils.py deleted file mode 100644 index 23e23e07e45..00000000000 --- a/examples/a2c/utils.py +++ /dev/null @@ -1,466 +0,0 @@ -import torch.nn -import torch.optim -from tensordict.nn import TensorDictModule - -from torchrl.collectors import SyncDataCollector -from torchrl.data import CompositeSpec - -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - CatTensors, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ObservationNorm, - ParallelEnv, - Resize, - RewardScaling, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - NormalParamWrapper, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) -from torchrl.objectives import A2CLoss -from torchrl.objectives.value.advantages import GAE -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.envs import get_norm_state_dict, LIBS - -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - - -# ==================================================================== -# Environment utils -# ----------------- - - -def make_base_env(env_cfg, from_pixels=None): - env_library = LIBS[env_cfg.env_library] - env_kwargs = { - "env_name": env_cfg.env_name, - "frame_skip": env_cfg.frame_skip, - "from_pixels": env_cfg.from_pixels - if from_pixels is None - else from_pixels, # for rendering - "pixels_only": False, - } - if env_library is DMControlEnv: - env_task = env_cfg.env_task - env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) - return env - - -def make_transformed_env(base_env, env_cfg): - if env_cfg.noop > 1: - base_env = TransformedEnv(env=base_env, transform=NoopResetEnv(env_cfg.noop)) - from_pixels = env_cfg.from_pixels - if from_pixels: - return make_transformed_env_pixels(base_env, env_cfg) - else: - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_pixels(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env_library = LIBS[env_cfg.env_library] - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - env.append_transform(RewardScaling(0.0, reward_scaling)) - - double_to_float_list = [] - double_to_float_inv_list = [] - - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) - return env - - -def make_transformed_env_states(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env_library = LIBS[env_cfg.env_library] - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - - env.append_transform(RewardScaling(0.0, 
reward_scaling)) - - double_to_float_list = [] - double_to_float_inv_list = [] - - # we concatenate all the state vectors - # even if there is a single tensor, it'll be renamed in "observation_vector" - selected_keys = [ - key for key in env.observation_spec.keys(True, True) if key != "pixels" - ] - out_key = "observation_vector" - env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - # obs_norm = ObservationNorm(in_keys=[out_key]) - # env.append_transform(obs_norm) - - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision - double_to_float_list += ["observation_vector"] - else: - double_to_float_list += ["observation_vector"] - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) - return env - - -def make_parallel_env(env_cfg, state_dict): - num_envs = env_cfg.num_envs - env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg - ) - for t in env.transform: - if isinstance(t, ObservationNorm): - t.init_stats(3, cat_dim=1, reduce_dim=[0, 1]) - env.load_state_dict(state_dict, strict=False) - return env - - -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - return get_norm_state_dict(env) - - -def init_stats(env, n_samples_stats, from_pixels): - for t in env.transform: - if isinstance(t, ObservationNorm): - if from_pixels: - t.init_stats( - n_samples_stats, - cat_dim=-3, - reduce_dim=(-1, -2, -3), - keep_dims=(-1, -2, -3), - ) - else: - t.init_stats(n_samples_stats) - - -def make_test_env(env_cfg): - env_cfg.num_envs = 1 - state_dict = get_stats(env_cfg) - env = make_parallel_env(env_cfg, state_dict=state_dict) - return env - - -# ==================================================================== -# Collector and replay buffer -# --------------------------- - - -def make_collector(cfg, policy): - env_cfg = cfg.env - collector_cfg = cfg.collector - collector_class = SyncDataCollector - state_dict = get_stats(env_cfg) - collector = collector_class( - make_parallel_env(env_cfg, state_dict=state_dict), - policy, - frames_per_batch=collector_cfg.frames_per_batch, - total_frames=collector_cfg.total_frames, - device=collector_cfg.collector_device, - max_frames_per_traj=collector_cfg.max_frames_per_traj, - ) - return collector - - -# ==================================================================== -# Model -# ----- -# -# We give one version of the model for learning from pixels, and one for state. -# TorchRL comes in handy at this point, as the high-level interactions with -# these models is unchanged, regardless of the modality. 
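Aside (self-contained sketch of that claim, not part of the diff; toy sizes and a discrete action head are assumed): the shared trunk is wrapped once and both heads reuse it through the ActorValueOperator accessors, so the downstream actor/critic calls look the same regardless of the modality:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.envs import ExplorationType
from torchrl.modules import (
    ActorValueOperator, MLP, OneHotCategorical, ProbabilisticActor, ValueOperator,
)

common = TensorDictModule(
    MLP(in_features=8, out_features=32, num_cells=[32]),
    in_keys=["observation_vector"], out_keys=["common_features"],
)
policy = ProbabilisticActor(
    TensorDictModule(
        MLP(in_features=32, out_features=4, num_cells=[]),
        in_keys=["common_features"], out_keys=["logits"],
    ),
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
    default_interaction_type=ExplorationType.RANDOM,
)
value = ValueOperator(
    MLP(in_features=32, out_features=1, num_cells=[]),
    in_keys=["common_features"],
)
actor_critic = ActorValueOperator(
    common_operator=common, policy_operator=policy, value_operator=value
)
td = TensorDict({"observation_vector": torch.randn(5, 8)}, batch_size=[5])
actor_critic.get_policy_operator()(td)  # writes "action", "sample_log_prob"
actor_critic.get_value_operator()(td)   # writes "state_value"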
- - -def make_a2c_models(cfg): - - env_cfg = cfg.env - from_pixels = env_cfg.from_pixels - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - - if not from_pixels: - # we must initialize the observation norm transform - # init_stats( - # proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels - # ) - common_module, policy_module, value_module = make_a2c_modules_state( - proof_environment - ) - else: - common_module, policy_module, value_module = make_a2c_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - if common_module is not None: - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ) - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_head() # to avoid duplicate params - else: - actor = policy_module - critic = value_module - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor(td) - td = critic(td) - del td - - return actor, critic - - -def make_a2c_modules_state(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["observation_vector"].shape - - # Define distribution class and kwargs - continuous_actions = False - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - continuous_actions = True - num_outputs = proof_environment.action_spec.shape[-1] * 2 - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - "tanh_loc": False, - } - - # Define input keys - in_keys = ["observation_vector"] - - # Define the policy net - policy_net = MLP( - in_features=input_shape[-1], - out_features=num_outputs, - num_cells=[64, 64], - activate_last_layer=False, - activation_class=torch.nn.Tanh, - ) - if continuous_actions: - policy_net = NormalParamWrapper(policy_net) - - policy_module = TensorDictModule( - module=policy_net, - in_keys=in_keys, - out_keys=["loc", "scale"] if continuous_actions else ["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["loc", "scale"] if continuous_actions else ["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define the value net - value_net = MLP( - in_features=input_shape[-1], - out_features=1, - num_cells=[64, 64], - activate_last_layer=False, - activation_class=torch.nn.Tanh, - ) - value_module = ValueOperator( - value_net, - in_keys=in_keys, - ) - - return None, policy_module, value_module - - -def make_a2c_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": 
proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn = ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - num_cells=[256], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[256] - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -# ==================================================================== -# A2C Loss -# --------- - - -def make_advantage_module(loss_cfg, value_network): - advantage_module = GAE( - gamma=loss_cfg.gamma, - lmbda=loss_cfg.gae_lamdda, - value_network=value_network, - average_gae=True, - ) - return advantage_module - - -def make_loss(loss_cfg, actor_network, value_network): - advantage_module = make_advantage_module(loss_cfg, value_network) - loss_module = A2CLoss( - actor=actor_network, - critic=value_network, - loss_critic_type=loss_cfg.loss_critic_type, - entropy_coef=loss_cfg.entropy_coef, - critic_coef=loss_cfg.critic_coef, - entropy_bonus=True, - ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) - return loss_module, advantage_module - - -def make_optim(optim_cfg, actor_network, value_network): - optim = torch.optim.Adam( - list(actor_network.parameters()) + list(value_network.parameters()), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - ) - return optim - - -# ==================================================================== -# Logging and recording -# --------------------- - - -def make_logger(logger_cfg): - exp_name = generate_exp_name("A2C", logger_cfg.exp_name) - logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="a2c", experiment_name=exp_name) - return logger diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py new file mode 100644 index 00000000000..63d15557700 --- /dev/null +++ b/examples/a2c/utils_atari.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
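Aside (minimal sketch, not part of the diff): both the removed utils.py above and the new utils_atari.py below size the shared MLP by running the CNN once on a dummy observation; the lazy convolution layers infer their input channels on that first call and the flattened output gives the feature width. Shapes assume 4 stacked 84x84 grayscale Atari frames:

import torch
from torchrl.modules import ConvNet

cnn = ConvNet(
    activation_class=torch.nn.ReLU,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
features = cnn(torch.ones(4, 84, 84))  # dummy forward pass, as in the code below
print(features.shape[-1])              # flattened feature size (3136 for this geometry)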
+ +import numpy as np +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EndOfLifeTransform, + EnvCreator, + ExplorationType, + GrayScale, + GymEnv, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = GymEnv( + env_name, + frame_skip=frame_skip, + from_pixels=True, + pixels_only=False, + device=device, + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + env.append_transform(EndOfLifeTransform()) + return env + + +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + 
out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, 1, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/examples/a2c/utils_mujoco.py b/examples/a2c/utils_mujoco.py new file mode 100644 index 00000000000..cdc681da522 --- /dev/null +++ b/examples/a2c/utils_mujoco.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + StepCounter, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu"): + env = GymEnv(env_name, device=device) + env = TransformedEnv(env) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + env.append_transform(VecNorm(in_keys=["observation"])) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = 
test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/examples/cql/utils.py b/examples/cql/utils.py index e458633892c..e67696488c1 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -41,7 +41,7 @@ def apply_env_transforms(env, reward_scaling=1.0): env, Compose( RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(in_keys=["observation"], in_keys_inv=[]), + DoubleToFloat(), ), ) return transformed_env @@ -129,12 +129,7 @@ def make_offline_replay_buffer(rb_cfg): sampler=SamplerWithoutReplacement(drop_last=False), ) - data.append_transform( - DoubleToFloat( - in_keys=["observation", ("next", "observation")], - in_keys_inv=[], - ) - ) + data.append_transform(DoubleToFloat()) return data @@ -169,8 +164,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 464632f8bf3..5997ccb8fb3 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -1,45 +1,47 @@ -# Environment +# environment and task env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-DDPG" - library: gym - frame_skip: 1 - seed: 1 + exp_name: ${env.name}_DDPG + library: gymnasium + max_episode_steps: 1000 + seed: 42 -# Collection +# collector collector: - total_frames: 1000000 - init_random_frames: 10000 + total_frames: 1_000_000 + init_random_frames: 25_000 frames_per_batch: 1000 - max_frames_per_traj: 1000 init_env_steps: 1000 - async_collection: 1 + reset_at_each_iter: False collector_device: cpu env_per_collector: 1 - num_workers: 1 -# Replay Buffer + +# replay buffer replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay + scratch_dir: ${env.exp_name}_${env.seed} -# Optimization -optimization: +# optimization +optim: utd_ratio: 1.0 gamma: 0.99 - loss_function: smooth_l1 - lr: 3e-4 - weight_decay: 2e-4 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 1e-4 batch_size: 256 target_update_polyak: 0.995 +# network network: hidden_sizes: [256, 256] activation: relu device: "cuda:0" + noise_type: "ou" # ou or gaussian -# Logging +# logging logger: backend: wandb mode: online diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index b77494bc52f..5688e561ae5 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -11,15 +11,19 @@ The helper functions are coded in the utils.py associated with this script. 
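Aside (toy illustration, not part of the diff): eval_model above and the training loops in these scripts log episode returns by indexing the running "episode_reward" written by the RewardSum transform at the steps where "done" is True, which picks out exactly one value per finished episode:

import torch

episode_reward = torch.tensor([1.0, 2.0, 3.0, 1.0, 2.0])  # running sums over two episodes
done = torch.tensor([False, False, True, False, True])
print(episode_reward[done])                                # tensor([3., 2.])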
""" +import time + import hydra import numpy as np import torch import torch.cuda import tqdm + from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_ddpg_agent, make_environment, @@ -33,6 +37,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create logger exp_name = generate_exp_name("DDPG", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -43,137 +48,145 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) + # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + # Create agent model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device) - # Create Loss Module and Target Updater + # Create DDPG loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) - # Make Optimizers + # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) - rewards = [] - rewards_eval = [] - # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip + frames_per_batch = cfg.collector.frames_per_batch eval_iter = cfg.logger.eval_iter - eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + eval_rollout_steps = cfg.env.max_episode_steps - for i, tensordict in enumerate(collector): + sampling_start = time.time() + for _, tensordict in enumerate(collector): + sampling_time = time.time() - sampling_start + # Update exploration policy exploration_policy.step(tensordict.numel()) - # update weights of the inference policy + + # Update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps + training_start = time.time() if collected_frames >= init_random_frames: ( actor_losses, q_losses, ) = ([], []) for _ in range(num_updates): - # sample from replay buffer + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() - loss_td = loss_module(sampled_tensordict) - + # Update critic + q_loss, *_ = loss_module.loss_value(sampled_tensordict) optimizer_critic.zero_grad() - optimizer_actor.zero_grad() - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_value"] - (actor_loss + 
q_loss).backward() - + q_loss.backward() optimizer_critic.step() - q_losses.append(q_loss.item()) + # Update actor + actor_loss, *_ = loss_module.loss_actor(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() optimizer_actor.step() + + q_losses.append(q_loss.item()) actor_losses.append(actor_loss.item()) - # update qnet_target params + # Update qnet_target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - } + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 65c51a0a9c1..5709c3ff59e 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -1,3 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
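Aside (rough sketch, not part of the diff; parameter bookkeeping simplified): the SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) created below keeps the target networks as an exponential moving average of the online weights, with eps=0.995 from the new config. Each step() call amounts to:

import torch

@torch.no_grad()
def soft_update(target_params, online_params, eps=0.995):
    # target <- eps * target + (1 - eps) * online
    for t, p in zip(target_params, online_params):
        t.mul_(eps).add_(p, alpha=1.0 - eps)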
import torch from torch import nn, optim @@ -10,12 +14,14 @@ EnvCreator, InitTracker, ParallelEnv, + RewardSum, + StepCounter, TransformedEnv, ) -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( + AdditiveGaussianWrapper, MLP, OrnsteinUhlenbeckProcessWrapper, SafeModule, @@ -33,17 +39,23 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) +def env_maker(task, device="cpu", from_pixels=False): + with set_gym_backend("gym"): + return GymEnv( + task, + device=device, + from_pixels=from_pixels, + ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps=1000): transformed_env = TransformedEnv( env, Compose( InitTracker(), - RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(in_keys=["observation"], in_keys_inv=[]), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), ), ) return transformed_env @@ -57,7 +69,9 @@ def make_environment(cfg): ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms(parallel_env) + train_env = apply_env_transforms( + parallel_env, max_episode_steps=cfg.env.max_episode_steps + ) eval_env = TransformedEnv( ParallelEnv( @@ -80,7 +94,8 @@ def make_collector(cfg, train_env, actor_model_explore): train_env, actor_model_explore, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, + init_random_frames=cfg.collector.init_random_frames, + reset_at_each_iter=cfg.collector.reset_at_each_iter, total_frames=cfg.collector.total_frames, device=cfg.collector.collector_device, ) @@ -128,17 +143,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def make_ddpg_agent(cfg, train_env, eval_env, device): """Make DDPG agent.""" # Define Actor Network @@ -199,10 +203,22 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): eval_env.close() # Exploration wrappers: - actor_model_explore = OrnsteinUhlenbeckProcessWrapper( - model[0], - annealing_num_steps=1_000_000, - ).to(device) + if cfg.network.noise_type == "ou": + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + model[0], + annealing_num_steps=1_000_000, + ).to(device) + elif cfg.network.noise_type == "gaussian": + actor_model_explore = AdditiveGaussianWrapper( + model[0], + sigma_end=1.0, + sigma_init=1.0, + mean=0.0, + std=0.1, + ).to(device) + else: + raise NotImplementedError + return model, actor_model_explore @@ -217,14 +233,14 @@ def make_loss_module(cfg, model): loss_module = DDPGLoss( actor_network=model[0], value_network=model[1], - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, + delay_actor=True, + delay_value=True, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater @@ 
-233,11 +249,32 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay + actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, ) return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py new file mode 100644 index 00000000000..f241ce4e975 --- /dev/null +++ b/examples/decision_transformer/dt.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Decision Transformer Example. +This is a self-contained example of an offline Decision Transformer training script. +The helper functions are coded in the utils.py associated with this script. +""" +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl.envs.libs.gym import set_gym_backend + +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper + +from utils import ( + log_metrics, + make_dt_loss, + make_dt_model, + make_dt_optimizer, + make_env, + make_logger, + make_offline_replay_buffer, +) + + +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden +@hydra.main(config_path=".", config_name="dt_config") +def main(cfg: "DictConfig"): # noqa: F821 + model_device = cfg.optim.device + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create logger + logger = make_logger(cfg) + + # Create offline replay buffer + offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( + cfg.replay_buffer, cfg.env.reward_scaling + ) + + # Create test environment + test_env = make_env(cfg.env, obs_loc, obs_std) + + # Create policy model + actor = make_dt_model(cfg) + policy = actor.to(model_device) + + # Create loss + loss_module = make_dt_loss(cfg.loss, actor) + + # Create optimizer + transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module) + + # Create inference policy + inference_policy = DecisionTransformerInferenceWrapper( + policy=policy, + inference_context=cfg.env.inference_context, + ).to(model_device) + + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) + + pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps + clip_grad = cfg.optim.clip_grad + eval_steps = cfg.logger.eval_steps + pretrain_log_interval = cfg.logger.pretrain_log_interval + reward_scaling = cfg.env.reward_scaling + + print(" ***Pretraining*** ") + # Pretraining + start_time = time.time() + for i in range(pretrain_gradient_steps): + pbar.update(i) + + # Sample data + data = offline_buffer.sample() + 
# Compute loss + loss_vals = loss_module(data.to(model_device)) + transformer_loss = loss_vals["loss"] + + transformer_optim.zero_grad() + torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) + transformer_loss.backward() + transformer_optim.step() + + scheduler.step() + + # Log metrics + to_log = {"train/loss": loss_vals["loss"]} + + # Evaluation + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + if i % pretrain_log_interval == 0: + eval_td = test_env.rollout( + max_steps=eval_steps, + policy=inference_policy, + auto_cast_to_device=True, + ) + to_log["eval/reward"] = ( + eval_td["next", "reward"].sum(1).mean().item() / reward_scaling + ) + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + print(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml new file mode 100644 index 00000000000..3514cf2203a --- /dev/null +++ b/examples/decision_transformer/dt_config.yaml @@ -0,0 +1,64 @@ +# environment and task +env: + name: HalfCheetah-v3 + task: "" + library: gym + stacked_frames: 20 + inference_context: 5 + n_samples_stats: 2000 + frame_skip: 1 + num_train_envs: 1 + num_eval_envs: 10 + reward_scaling: 0.001 # for r2g + noop: 1 + seed: 1 + target_return_mode: reduce + eval_target_return: 6000 + collect_target_return: 12000 + +# logger +logger: + backend: wandb + model_name: DT + exp_name: DT-HalfCheetah-medium-v2 + pretrain_log_interval: 500 # record interval in frames + fintune_log_interval: 1 + eval_steps: 1000 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 64 + prb: 0 + stacked_frames: 20 + buffer_prefetch: 64 + capacity: 1_000_000 + buffer_scratch_dir: "/tmp/" + device: cpu + prefetch: 3 + +# optimization +optim: + device: cuda:0 + lr: 1.0e-4 + weight_decay: 5.0e-4 + batch_size: 64 + pretrain_gradient_steps: 55000 + updates_per_episode: 300 + warmup_steps: 10000 + clip_grad: 0.25 + +# loss +loss: + loss_function: "l2" + +# transformer model +transformer: + n_embd: 128 + n_layer: 3 + n_head: 1 + n_inner: 512 # 4*128 + activation: relu + n_positions: 1024 + resid_pdrop: 0.1 + attn_pdrop: 0.1 diff --git a/examples/decision_transformer/lamb.py b/examples/decision_transformer/lamb.py new file mode 100644 index 00000000000..7f874b6e049 --- /dev/null +++ b/examples/decision_transformer/lamb.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# Lamb optimizer directly copied from https://github.com/facebookresearch/online-dt +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. 
(default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False, + ): + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "grad_averaging": grad_averaging, + "max_grad_norm": max_grad_norm, + "trust_clip": trust_clip, + "always_adapt": always_adapt, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]["params"][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." 
+ ) + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor, + ) + + for group in self.param_groups: + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group["step"] + bias_correction2 = 1 - beta2 ** group["step"] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state["exp_avg"] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group["weight_decay"] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group["always_adapt"]: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. 
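Aside (simplified sketch, not part of the diff): the lines that follow implement LAMB's layer-wise trust ratio, which rescales the Adam-style update by the ratio of the parameter norm to the update norm before applying it. Ignoring the torch.where workarounds for XLA, the scaling is:

import torch

def lamb_trust_scale(param: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    w_norm, u_norm = param.norm(2.0), update.norm(2.0)
    if w_norm > 0 and u_norm > 0:
        return update * (w_norm / u_norm)
    return update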
+ w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group["trust_clip"]: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml new file mode 100644 index 00000000000..f8aebd30091 --- /dev/null +++ b/examples/decision_transformer/odt_config.yaml @@ -0,0 +1,65 @@ +# environment and task +env: + name: HalfCheetah-v3 + task: "" + library: gym + stacked_frames: 20 + inference_context: 5 + n_samples_stats: 2000 + frame_skip: 1 + num_train_envs: 1 + num_eval_envs: 10 + reward_scaling: 0.001 # for r2g + seed: 42 + target_return_mode: reduce + eval_target_return: 6000 + collect_target_return: 12000 + + +# logger +logger: + backend: wandb + exp_name: oDT-HalfCheetah-medium-v2 + model_name: oDT + pretrain_log_interval: 500 # record interval in frames + fintune_log_interval: 1 + eval_steps: 1000 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 256 + prb: 0 + stacked_frames: 20 + buffer_prefetch: 64 + capacity: 1_000_000 + buffer_scratch_dir: "/tmp/" + device: cuda:0 + prefetch: 3 + +# optimizer +optim: + device: cuda:0 + lr: 1.0e-4 + weight_decay: 5.0e-4 + batch_size: 256 + pretrain_gradient_steps: 10000 + updates_per_episode: 300 + warmup_steps: 10000 + clip_grad: 0.25 + +# loss +loss: + alpha_init: 0.1 + target_entropy: auto + +# transformer model +transformer: + n_embd: 512 + n_layer: 4 + n_head: 4 + n_inner: 2048 # 4*512 + activation: relu + n_positions: 1024 + resid_pdrop: 0.1 + attn_pdrop: 0.1 diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py new file mode 100644 index 00000000000..131320e9e21 --- /dev/null +++ b/examples/decision_transformer/online_dt.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Online Decision Transformer Example. +This is a self-contained example of an Online Decision Transformer training script. +The helper functions are coded in the utils.py associated with this script. 
+""" + +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl.envs.libs.gym import set_gym_backend + +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper + +from utils import ( + log_metrics, + make_env, + make_logger, + make_odt_loss, + make_odt_model, + make_odt_optimizer, + make_offline_replay_buffer, +) + + +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden +@hydra.main(config_path=".", config_name="odt_config") +def main(cfg: "DictConfig"): # noqa: F821 + model_device = cfg.optim.device + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create logger + logger = make_logger(cfg) + + # Create offline replay buffer + offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( + cfg.replay_buffer, cfg.env.reward_scaling + ) + + # Create test environment + test_env = make_env(cfg.env, obs_loc, obs_std) + + # Create policy model + actor = make_odt_model(cfg) + policy = actor.to(model_device) + + # Create loss + loss_module = make_odt_loss(cfg.loss, policy) + + # Create optimizer + transformer_optim, temperature_optim, scheduler = make_odt_optimizer( + cfg.optim, loss_module + ) + + # Create inference policy + inference_policy = DecisionTransformerInferenceWrapper( + policy=policy, + inference_context=cfg.env.inference_context, + ).to(model_device) + + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) + + pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps + clip_grad = cfg.optim.clip_grad + eval_steps = cfg.logger.eval_steps + pretrain_log_interval = cfg.logger.pretrain_log_interval + reward_scaling = cfg.env.reward_scaling + + print(" ***Pretraining*** ") + # Pretraining + start_time = time.time() + for i in range(pretrain_gradient_steps): + pbar.update(i) + # Sample data + data = offline_buffer.sample() + # Compute loss + loss_vals = loss_module(data.to(model_device)) + transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] + temperature_loss = loss_vals["loss_alpha"] + + transformer_optim.zero_grad() + torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) + transformer_loss.backward() + transformer_optim.step() + + temperature_optim.zero_grad() + temperature_loss.backward() + temperature_optim.step() + + scheduler.step() + + # Log metrics + to_log = { + "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(), + "train/loss_entropy": loss_vals["loss_entropy"].item(), + "train/loss_alpha": loss_vals["loss_alpha"].item(), + "train/alpha": loss_vals["alpha"].item(), + "train/entropy": loss_vals["entropy"].item(), + } + + # Evaluation + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + inference_policy.eval() + if i % pretrain_log_interval == 0: + eval_td = test_env.rollout( + max_steps=eval_steps, + policy=inference_policy, + auto_cast_to_device=True, + break_when_any_done=False, + ) + inference_policy.train() + to_log["eval/reward"] = ( + eval_td["next", "reward"].sum(1).mean().item() / reward_scaling + ) + + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + print(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py new file mode 100644 index 00000000000..595ac5ecf6e --- /dev/null +++ b/examples/decision_transformer/utils.py @@ -0,0 +1,481 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn + +import torch.optim +from lamb import Lamb +from tensordict.nn import TensorDictModule + +from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import RandomSampler +from torchrl.envs import ( + CatFrames, + Compose, + DoubleToFloat, + EnvCreator, + ExcludeTransform, + ObservationNorm, + RandomCropTensorDict, + Reward2GoTransform, + RewardScaling, + RewardSum, + SerialEnv, + TargetReturn, + TensorDictPrimer, + TransformedEnv, + UnsqueezeTransform, +) +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import set_gym_backend +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import ( + DTActor, + OnlineDTActor, + ProbabilisticActor, + TanhDelta, + TanhNormal, +) + +from torchrl.objectives import DTLoss, OnlineDTLoss +from torchrl.record.loggers import generate_exp_name, get_logger +from torchrl.trainers.helpers.envs import LIBS + +# ==================================================================== +# Environment utils +# ----------------- + + +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden +def make_base_env(env_cfg): + env_library = LIBS[env_cfg.library] + env_name = env_cfg.name + frame_skip = env_cfg.frame_skip + + env_kwargs = { + "env_name": env_name, + "frame_skip": frame_skip, + } + if env_library is DMControlEnv: + env_task = env_cfg.task + env_kwargs.update({"task_name": env_task}) + env = env_library(**env_kwargs) + return env + + +def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): + transformed_env = TransformedEnv(base_env) + transformed_env.append_transform( + RewardScaling( + loc=0, + scale=env_cfg.reward_scaling, + in_keys=["reward"], + standard_normal=False, + ) + ) + if train: + transformed_env.append_transform( + TargetReturn( + env_cfg.collect_target_return * env_cfg.reward_scaling, + out_keys=["return_to_go_single"], + mode=env_cfg.target_return_mode, + ) + ) + else: + transformed_env.append_transform( + TargetReturn( + env_cfg.eval_target_return * env_cfg.reward_scaling, + out_keys=["return_to_go_single"], + mode=env_cfg.target_return_mode, + ) + ) + + transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) + + transformed_env.append_transform( + DoubleToFloat( + in_keys=["observation"], + in_keys_inv=[], + ) + ) + obsnorm = ObservationNorm( + loc=obs_loc, scale=obs_std, in_keys="observation", standard_normal=True + ) + transformed_env.append_transform(obsnorm) + transformed_env.append_transform( + UnsqueezeTransform( + -2, + in_keys=["observation", "action", "return_to_go_single"], + out_keys=["observation", "action", "return_to_go"], + ) + ) + transformed_env.append_transform( + CatFrames( + in_keys=["observation", "action", "return_to_go"], + N=env_cfg.stacked_frames, + dim=-2, + padding="zeros", + ) + ) + + if train: + transformed_env.append_transform(RewardSum()) + + return transformed_env + + +def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): + if train: + num_envs = env_cfg.num_train_envs + else: + num_envs = env_cfg.num_eval_envs + + def make_env(): + with set_gym_backend("gym"): + return make_base_env(env_cfg) + + env = make_transformed_env( + SerialEnv(num_envs, 
EnvCreator(make_env)), + env_cfg, + obs_loc, + obs_std, + train, + ) + return env + + +def make_env(env_cfg, obs_loc, obs_std, train=False): + env = make_parallel_env(env_cfg, obs_loc, obs_std, train=train) + return env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, policy): + exclude_target_return = ExcludeTransform( + "return_to_go", + ("next", "return_to_go"), + "return_to_go_single", + ("next", "return_to_go_single"), + ("next", "action"), + ("next", "observation"), + "scale", + "loc", + ) + cat = CatFrames(in_keys=["action"], N=20, dim=-2, padding="zeros") + transforms = Compose( + exclude_target_return, + cat, + ) + collector_cfg = cfg.collector + collector_class = SyncDataCollector + collector = collector_class( + make_env(cfg.env, train=True), + policy, + frames_per_batch=collector_cfg.frames_per_batch, + total_frames=collector_cfg.total_frames, + device=collector_cfg.collector_devices, + max_frames_per_traj=collector_cfg.max_frames_per_traj, + postproc=transforms, + ) + return collector + + +def make_offline_replay_buffer(rb_cfg, reward_scaling): + r2g = Reward2GoTransform( + gamma=1.0, in_keys=["reward"], out_keys=["return_to_go_single"] + ) + reward_scale = RewardScaling( + loc=0, + scale=reward_scaling, + in_keys="return_to_go_single", + out_keys=["return_to_go"], + standard_normal=False, + ) + crop_seq = RandomCropTensorDict(sub_seq_len=rb_cfg.stacked_frames, sample_dim=-1) + + d2f = DoubleToFloat( + in_keys=["observation", ("next", "observation")], + in_keys_inv=[], + ) + exclude = ExcludeTransform( + "next_observations", + # "timeout", + "terminal", + "info", + ("next", "timeout"), + ("next", "terminal"), + ("next", "observation"), + ("next", "info"), + ) + + transforms = Compose( + r2g, + crop_seq, + reward_scale, + d2f, + exclude, + ) + data = D4RLExperienceReplay( + rb_cfg.dataset, + split_trajs=True, + batch_size=rb_cfg.batch_size, + sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False), + transform=transforms, + use_truncated_as_done=True, + ) + full_data = data._get_dataset_from_env(rb_cfg.dataset, {}) + loc = full_data["observation"].mean(axis=0).float() + std = full_data["observation"].std(axis=0).float() + obsnorm = ObservationNorm( + loc=loc, scale=std, in_keys="observation", standard_normal=True + ) + data.append_transform(obsnorm) + return data, loc, std + + +def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): + r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go_single"]) + reward_scale = RewardScaling( + loc=0, + scale=reward_scaling, + in_keys=["return_to_go_single"], + out_keys=["return_to_go"], + standard_normal=False, + ) + catframes = CatFrames( + in_keys=["return_to_go_single"], + out_keys=["return_to_go"], + N=rb_cfg.stacked_frames, + dim=-2, + padding="zeros", + as_inverse=True, + ) + transforms = Compose( + r2g, + reward_scale, + catframes, # TODO: cat frames is not an inverse transform doesnt get triggered! 
+ ) + storage = LazyMemmapStorage( + rb_cfg.capacity, rb_cfg.buffer_scratch_dir, device=rb_cfg.device + ) + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=rb_cfg.prefetch, + storage=storage, + batch_size=rb_cfg.batch_size, + ) + # init buffer with offline data + offline_data = offline_buffer.sample(100000) + offline_data.del_("index") + replay_buffer.extend(offline_data.clone().detach().to_tensordict()) + # add transforms after offline data extension to not trigger reward-to-go calculation + replay_buffer.append_transform(transforms) + + return replay_buffer + + +# ==================================================================== +# Model +# ----- + + +def make_odt_model(cfg): + env_cfg = cfg.env + proof_environment = make_transformed_env( + make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 + ) + + action_spec = proof_environment.action_spec + for key, value in proof_environment.observation_spec.items(): + if key == "observation": + state_dim = value.shape[-1] + in_keys = [ + "observation", + "action", + "return_to_go", + ] + + actor_net = OnlineDTActor( + state_dim=state_dim, + action_dim=action_spec.shape[-1], + transformer_config=cfg.transformer, + ) + + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=[ + "loc", + "scale", + ], + ) + dist_class = TanhNormal + dist_kwargs = {"min": -1.0, "max": 1.0, "tanh_loc": False, "upscale": 5.0} + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + out_keys=["action"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=False, + return_log_prob=False, + ) + + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + td = proof_environment.rollout(max_steps=100) + td["action"] = td["next", "action"] + actor(td) + + return actor + + +def make_dt_model(cfg): + env_cfg = cfg.env + proof_environment = make_transformed_env( + make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 + ) + + action_spec = proof_environment.action_spec + for key, value in proof_environment.observation_spec.items(): + if key == "observation": + state_dim = value.shape[-1] + in_keys = [ + "observation", + "action", + "return_to_go", + ] + + actor_net = DTActor( + state_dim=state_dim, + action_dim=action_spec.shape[-1], + transformer_config=cfg.transformer, + ) + + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=["param"], + ) + dist_class = TanhDelta + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + } + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["param"], + out_keys=["action"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=False, + return_log_prob=False, + ) + + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + td = proof_environment.rollout(max_steps=100) + td["action"] = td["next", "action"] + actor(td) + + return actor + + +# ==================================================================== +# Online Decision Transformer Loss +# --------- + + +def make_odt_loss(loss_cfg, actor_network): + loss = OnlineDTLoss( + actor_network, + alpha_init=loss_cfg.alpha_init, + target_entropy=loss_cfg.target_entropy, + ) + return loss + + +def make_dt_loss(loss_cfg, actor_network): + loss = DTLoss( + actor_network, + loss_function=loss_cfg.loss_function, + ) + return loss + + +def 
make_odt_optimizer(optim_cfg, loss_module): + dt_optimizer = Lamb( + loss_module.actor_network_params.flatten_keys().values(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR( + dt_optimizer, lambda steps: min((steps + 1) / optim_cfg.warmup_steps, 1) + ) + + log_temp_optimizer = torch.optim.Adam( + [loss_module.log_alpha], + lr=1e-4, + betas=[0.9, 0.999], + ) + + return dt_optimizer, log_temp_optimizer, scheduler + + +def make_dt_optimizer(optim_cfg, loss_module): + dt_optimizer = torch.optim.Adam( + loss_module.actor_network_params.flatten_keys().values(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR( + dt_optimizer, lambda steps: min((steps + 1) / optim_cfg.warmup_steps, 1) + ) + + return dt_optimizer, scheduler + + +# ==================================================================== +# Logging and recording +# --------------------- + + +def make_logger(cfg): + if not cfg.logger.backend: + return None + exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name) + cfg.logger.exp_name = exp_name + logger = get_logger( + cfg.logger.backend, + logger_name=cfg.logger.model_name, + experiment_name=exp_name, + wandb_kwargs={"config": cfg}, + ) + return logger + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index c5c03cf7042..12ac76f20e7 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -9,7 +9,7 @@ import torch import torch.cuda import tqdm -from tensordict.nn import InteractionType +from tensordict.nn import InteractionType, TensorDictModule from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -27,7 +27,7 @@ from torchrl.modules import MLP, SafeModule from torchrl.modules.distributions import OneHotCategorical -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator +from torchrl.modules.tensordict_module.actors import ProbabilisticActor from torchrl.objectives import DiscreteSACLoss, SoftUpdate from torchrl.record.loggers import generate_exp_name, get_logger @@ -150,8 +150,9 @@ def env_factory(num_workers): **qvalue_net_kwargs, ) - qvalue = ValueOperator( + qvalue = TensorDictModule( in_keys=in_keys, + out_keys=["action_value"], module=qvalue_net, ).to(device) @@ -171,6 +172,7 @@ def env_factory(num_workers): # Create SAC loss loss_module = DiscreteSACLoss( actor_network=model[0], + action_space=test_env.action_spec, qvalue_network=model[1], num_actions=num_actions, num_qvalue_nets=2, diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 4d8b2f37fc4..a5265f442b7 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -58,9 +58,7 @@ Compose( # normalize observations ObservationNorm(in_keys=["observation"]), - DoubleToFloat( - in_keys=["observation"], - ), + DoubleToFloat(), StepCounter(), ), ) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index ae846a0fef6..c16337aa087 100644 --- a/examples/dreamer/dreamer_utils.py +++ 
b/examples/dreamer/dreamer_utils.py @@ -121,9 +121,7 @@ def make_env_transforms( "reward", ] float_to_double_list += ["action"] # DMControl requires double-precision - env.append_transform( - DoubleToFloat(in_keys=double_to_float_list, in_keys_inv=float_to_double_list) - ) + env.append_transform(DoubleToFloat()) default_dict = { "state": UnboundedContinuousTensorSpec(shape=(*env.batch_size, cfg.state_dim)), diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index cbe9f697a65..16014f4f3ec 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -76,12 +76,14 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.device) exp_name = generate_exp_name("Online_IQL", cfg.exp_name) - logger = get_logger( - logger_type=cfg.logger, - logger_name="iql_logging", - experiment_name=exp_name, - wandb_kwargs={"mode": cfg.mode}, - ) + logger = None + if cfg.logger: + logger = get_logger( + logger_type=cfg.logger, + logger_name="iql_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.mode}, + ) torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) @@ -300,8 +302,9 @@ def env_factory(num_workers): "value_loss": np.mean(value_losses), } ) - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) + if logger is not None: + for key, value in train_log.items(): + logger.log_scalar(key, value, step=collected_frames) with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): eval_rollout = test_env.rollout( @@ -312,7 +315,10 @@ def env_factory(num_workers): eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() rewards_eval.append((i, eval_reward)) eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + if logger is not None: + logger.log_scalar( + "test_reward", rewards_eval[-1][1], step=collected_frames + ) if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str diff --git a/examples/multiagent/README.md b/examples/multiagent/README.md new file mode 100644 index 00000000000..7c9e8b708be --- /dev/null +++ b/examples/multiagent/README.md @@ -0,0 +1,69 @@ +# Multi-agent examples + +In this folder we provide a set of multi-agent example scripts using the [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) simulator. + +
The MARL algorithm scripts in this folder run on three multi-robot tasks in VMAS.
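
For orientation, the sketch below shows how such a VMAS task can be instantiated through TorchRL's `VmasEnv` wrapper, mirroring the environment construction used in the scripts of this folder. It is illustrative only: the scenario name, number of vectorized environments and agent count are defaults borrowed from the configs rather than a required setup, and it assumes `vmas` is installed as described in the install section below.

```python
# Illustrative sketch: build one of the VMAS tasks the way the scripts do
# and collect a short random-policy rollout. The values are taken from the
# example configs (scenario "balance", 3 agents) and are not prescriptive.
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(
    scenario="balance",       # multi-robot task used in the example configs
    num_envs=32,              # number of vectorized environments on one device
    continuous_actions=True,
    max_steps=100,
    device="cpu",
    n_agents=3,               # scenario kwarg, as in the yaml files
)
td = env.rollout(max_steps=10)               # short rollout with a random policy
print(td["agents", "observation"].shape)     # (num_envs, steps, n_agents, obs_dim)
```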
+ +For more details on the experiment setup and the environments please refer to the corresponding section of the appendix in the [TorchRL paper](https://arxiv.org/abs/2306.00577). + +## Using the scripts + +### Install + +First you need to install vmas and the dependencies of the scripts. + +Install torchrl and tensordict following repo instructions. + +Install vmas and dependencies: + +```bash +pip install vmas +pip install wandb moviepy +pip install hydra-core +``` + +### Run + +To run the scripts just execute the corresponding python file after having modified the corresponding config +according to your needs. +The config can be found in the .yaml file with the same name. + +For example: +```bash +python mappo_ippo.py +``` + +You can even change the config from the command line like: + +```bash +python mappo_ippo.py --m env.scenario_name=navigation +``` + +### Computational demand +The scripts are set up for collecting many frames, if your compute is limited, you can change the "frames_per_batch" +and "num_epochs" parameters to reduce compute requirements. + +### Script structure + +The scripts are self-contained. +This means that all the code you will need to look at is contained in the script file. +No helper functions are used. + +The structure of scripts follows this order: +- Configuration dictionary for the script +- Environment creation +- Modules creation +- Collector instantiation +- Replay buffer instantiation +- Loss module creation +- Training loop (with inner minibatch loops) +- Evaluation run (at the desired frequency) + +Logging is done by default to wandb. +The logging backend can be changed in the config files to one of "wandb", "tensorboard", "csv", "mlflow". + +All the scripts follow the same on-policy training structure so that results can be compared across different algorithms. diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py new file mode 100644 index 00000000000..351f5c3730e --- /dev/null +++ b/examples/multiagent/iql.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from torch import nn +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules.models.multiagent import MultiAgentMLP +from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators +from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="iql") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=False, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=False, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + net = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + module = TensorDictModule( + net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")] + ) + value_module = QValueModule( + action_value_key=("agents", "action_value"), + out_keys=[ + env.action_key, + ("agents", "action_value"), + ("agents", "chosen_action_value"), + ], + spec=env.unbatched_action_spec, + action_space=None, + ) + qnet = SafeSequential(module, value_module) + + qnet_explore = EGreedyWrapper( + qnet, + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ) + + collector = SyncDataCollector( + env, + qnet_explore, + device=cfg.env.device, + storing_device=cfg.train.device, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + loss_module = DQNLoss(qnet, delay_value=True) + 
loss_module.set_keys( + action_value=("agents", "action_value"), + action=env.action_key, + value=("agents", "chosen_action_value"), + reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) + loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) + target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) + + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ("Het" if not cfg.model.shared_parameters else "") + "IQL" + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + training_tds.append(loss_vals.detach()) + + loss_value = loss_vals["loss"] + + loss_value.backward() + + total_norm = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), cfg.train.max_grad_norm + ) + training_tds[-1].set("grad_norm", total_norm.mean()) + + optim.step() + optim.zero_grad() + target_net_updater.step() + + qnet_explore.step(frames=current_frames) # Update exploration annealing + collector.update_policy_weights_() + + training_time = time.time() - training_start + + iteration_time = sampling_time + training_time + total_time += iteration_time + training_tds = torch.stack(training_tds) + + # More logs + if cfg.logger.backend: + log_training( + logger, + training_tds, + tensordict_data, + sampling_time, + training_time, + total_time, + i, + current_frames, + total_frames, + step=i, + ) + + if ( + cfg.eval.evaluation_episodes > 0 + and i % cfg.eval.evaluation_interval == 0 + and cfg.logger.backend + ): + evaluation_start = time.time() + with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + env_test.frames = [] + rollouts = env_test.rollout( + max_steps=cfg.env.max_steps, + policy=qnet, + callback=rendering_callback, + auto_cast_to_device=True, + break_when_any_done=False, + # We are running vectorized evaluation we do not want it to stop when just one env is done + ) + + evaluation_time = time.time() - evaluation_start + + log_evaluation(logger, rollouts, env_test, evaluation_time, step=i) + + if cfg.logger.backend == "wandb": + logger.experiment.log({}, commit=True) + sampling_start = time.time() + + +if __name__ == "__main__": + train() diff --git a/examples/multiagent/iql.yaml b/examples/multiagent/iql.yaml new file mode 100644 index 00000000000..801503b7e9d --- /dev/null +++ b/examples/multiagent/iql.yaml @@ -0,0 +1,38 @@ +seed: 0 + +env: + max_steps: 100 + scenario_name: "balance" + scenario: + n_agents: 3 + device: ??? # These values will be populated dynamically + vmas_envs: ??? + +model: + shared_parameters: True + +collector: + frames_per_batch: 60_000 # Frames sampled each sampling iteration + n_iters: 500 # Number of sampling/training iterations + total_frames: ??? + +buffer: + memory_size: ??? 
+ +loss: + gamma: 0.9 + tau: 0.005 # For target net + +train: + num_epochs: 45 # optimization steps per batch of data collected + minibatch_size: 4096 # size of minibatches used in each epoch + lr: 5e-5 + max_grad_norm: 40.0 + device: ??? + +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py new file mode 100644 index 00000000000..9301f8a63f2 --- /dev/null +++ b/examples/multiagent/maddpg_iddpg.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from torch import nn +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + AdditiveGaussianWrapper, + ProbabilisticActor, + TanhDelta, + ValueOperator, +) +from torchrl.modules.models.multiagent import MultiAgentMLP +from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators +from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="maddpg_iddpg") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=True, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=True, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + actor_net = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.shape[-1], + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + policy_module = TensorDictModule( + actor_net, in_keys=[("agents", "observation")], out_keys=[("agents", "param")] + ) + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "param")], + out_keys=[env.action_key], + distribution_class=TanhDelta, + distribution_kwargs={ + "min": 
env.unbatched_action_spec[("agents", "action")].space.low, + "max": env.unbatched_action_spec[("agents", "action")].space.high, + }, + return_log_prob=False, + ) + + policy_explore = AdditiveGaussianWrapper( + policy, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1] + + env.action_spec.shape[-1], # Q critic takes action and value + n_agent_outputs=1, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation"), env.action_key], + out_keys=[("agents", "state_action_value")], + ) + + collector = SyncDataCollector( + env, + policy_explore, + device=cfg.env.device, + storing_device=cfg.train.device, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + loss_module = DDPGLoss( + actor_network=policy, value_network=value_module, delay_value=True + ) + loss_module.set_keys( + state_action_value=("agents", "state_action_value"), + reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) + loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) + target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) + + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ( + ("Het" if not cfg.model.shared_parameters else "") + + ("MA" if cfg.model.centralised_critic else "I") + + "DDPG" + ) + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + training_tds.append(loss_vals.detach()) + + loss_value = loss_vals["loss_actor"] + loss_vals["loss_value"] + + loss_value.backward() + + total_norm = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), cfg.train.max_grad_norm + ) + training_tds[-1].set("grad_norm", total_norm.mean()) + + optim.step() + optim.zero_grad() + target_net_updater.step() + + policy_explore.step(frames=current_frames) # Update exploration annealing + collector.update_policy_weights_() + + training_time = time.time() - training_start + + iteration_time = sampling_time + training_time + total_time += iteration_time + training_tds = torch.stack(training_tds) + + # More logs + if cfg.logger.backend: + log_training( + logger, + training_tds, + tensordict_data, + sampling_time, + training_time, + total_time, + i, + current_frames, + total_frames, + step=i, + ) + + if ( + 
cfg.eval.evaluation_episodes > 0 + and i % cfg.eval.evaluation_interval == 0 + and cfg.logger.backend + ): + evaluation_start = time.time() + with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + env_test.frames = [] + rollouts = env_test.rollout( + max_steps=cfg.env.max_steps, + policy=policy, + callback=rendering_callback, + auto_cast_to_device=True, + break_when_any_done=False, + # We are running vectorized evaluation we do not want it to stop when just one env is done + ) + + evaluation_time = time.time() - evaluation_start + + log_evaluation(logger, rollouts, env_test, evaluation_time, step=i) + + if cfg.logger.backend == "wandb": + logger.experiment.log({}, commit=True) + sampling_start = time.time() + + +if __name__ == "__main__": + train() diff --git a/examples/multiagent/maddpg_iddpg.yaml b/examples/multiagent/maddpg_iddpg.yaml new file mode 100644 index 00000000000..19328cbc39e --- /dev/null +++ b/examples/multiagent/maddpg_iddpg.yaml @@ -0,0 +1,39 @@ +seed: 0 + +env: + max_steps: 100 + scenario_name: "balance" + scenario: + n_agents: 3 + device: ??? # These values will be populated dynamically + vmas_envs: ??? + +model: + shared_parameters: False # MADDPG paper does not use shared params because reward function can be different + centralised_critic: True # MADDPG if True, IDDPG if False + +collector: + frames_per_batch: 60_000 # Frames sampled each sampling iteration + n_iters: 500 # Number of sampling/training iterations + total_frames: ??? + +buffer: + memory_size: ??? + +loss: + gamma: 0.9 + tau: 0.005 # For target net + +train: + num_epochs: 45 # optimization steps per batch of data collected + minibatch_size: 4096 # size of minibatches used in each epoch + lr: 5e-5 + max_grad_norm: 40.0 + device: ??? + +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py new file mode 100644 index 00000000000..c2e46174e92 --- /dev/null +++ b/examples/multiagent/mappo_ippo.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
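
As a side note on the `???` placeholders in these configs: each training script resolves them at start-up from the values above. With the defaults shown, the arithmetic works out as in the small sketch below (plain Python; the variable names simply mirror the config keys).

```python
# How the scripts fill in the "???" entries, using the defaults from the yaml above.
frames_per_batch = 60_000   # collector.frames_per_batch
max_steps = 100             # env.max_steps
n_iters = 500               # collector.n_iters

vmas_envs = frames_per_batch // max_steps   # 600 vectorized envs per sampling iteration
total_frames = frames_per_batch * n_iters   # 30_000_000 frames over the whole run
memory_size = frames_per_batch              # replay buffer holds exactly one collected batch

print(vmas_envs, total_frames, memory_size)  # 600 30000000 60000
```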
+ + +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.modules.models.multiagent import MultiAgentMLP +from torchrl.objectives import ClipPPOLoss, ValueEstimators +from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="mappo_ippo") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=True, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=True, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + actor_net = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=2 * env.action_spec.shape[-1], + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + NormalParamExtractor(), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], + ) + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "loc"), ("agents", "scale")], + out_keys=[env.action_key], + distribution_class=TanhNormal, + distribution_kwargs={ + "min": env.unbatched_action_spec[("agents", "action")].space.low, + "max": env.unbatched_action_spec[("agents", "action")].space.high, + }, + return_log_prob=True, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=1, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation")], + ) + + collector = SyncDataCollector( + env, + policy, + device=cfg.env.device, + storing_device=cfg.train.device, + 
frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + # Loss + loss_module = ClipPPOLoss( + actor=policy, + critic=value_module, + clip_epsilon=cfg.loss.clip_epsilon, + entropy_coef=cfg.loss.entropy_eps, + normalize_advantage=False, + ) + loss_module.set_keys( + reward=env.reward_key, + action=env.action_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) + loss_module.make_value_estimator( + ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda + ) + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ( + ("Het" if not cfg.model.shared_parameters else "") + + ("MA" if cfg.model.centralised_critic else "I") + + "PPO" + ) + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + with torch.no_grad(): + loss_module.value_estimator( + tensordict_data, + params=loss_module.critic_params, + target_params=loss_module.target_critic_params, + ) + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + training_tds.append(loss_vals.detach()) + + loss_value = ( + loss_vals["loss_objective"] + + loss_vals["loss_critic"] + + loss_vals["loss_entropy"] + ) + + loss_value.backward() + + total_norm = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), cfg.train.max_grad_norm + ) + training_tds[-1].set("grad_norm", total_norm.mean()) + + optim.step() + optim.zero_grad() + + collector.update_policy_weights_() + + training_time = time.time() - training_start + + iteration_time = sampling_time + training_time + total_time += iteration_time + training_tds = torch.stack(training_tds) + + # More logs + if cfg.logger.backend: + log_training( + logger, + training_tds, + tensordict_data, + sampling_time, + training_time, + total_time, + i, + current_frames, + total_frames, + step=i, + ) + + if ( + cfg.eval.evaluation_episodes > 0 + and i % cfg.eval.evaluation_interval == 0 + and cfg.logger.backend + ): + evaluation_start = time.time() + with torch.no_grad(): + env_test.frames = [] + rollouts = env_test.rollout( + max_steps=cfg.env.max_steps, + policy=policy, + callback=rendering_callback, + auto_cast_to_device=True, + break_when_any_done=False, + # We are running vectorized evaluation we do not want it to stop when just one env is done + ) + + evaluation_time = time.time() - evaluation_start + + log_evaluation(logger, rollouts, env_test, evaluation_time, step=i) + + if cfg.logger.backend == "wandb": + logger.experiment.log({}, commit=True) + sampling_start = time.time() + + +if __name__ == "__main__": + train() diff --git a/examples/multiagent/mappo_ippo.yaml b/examples/multiagent/mappo_ippo.yaml new file mode 100644 index 00000000000..befec1cf1ca --- /dev/null +++ 
b/examples/multiagent/mappo_ippo.yaml @@ -0,0 +1,41 @@ +seed: 0 + +env: + max_steps: 100 + scenario_name: "balance" + scenario: + n_agents: 3 + device: ??? # These values will be populated dynamically + vmas_envs: ??? + +model: + shared_parameters: True + centralised_critic: True # MAPPO if True, IPPO if False + +collector: + frames_per_batch: 60_000 # Frames sampled each sampling iteration + n_iters: 500 # Number of sampling/training iterations + total_frames: ??? + +buffer: + memory_size: ??? + +loss: + gamma: 0.9 + lmbda: 0.9 + entropy_eps: 0 + clip_epsilon: 0.2 + +train: + num_epochs: 45 # optimization steps per batch of data collected + minibatch_size: 4096 # size of minibatches used in each epoch + lr: 5e-5 + max_grad_norm: 40.0 + device: ??? + +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py new file mode 100644 index 00000000000..222e0434db2 --- /dev/null +++ b/examples/multiagent/qmix_vdn.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from torch import nn +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer +from torchrl.objectives import SoftUpdate, ValueEstimators +from torchrl.objectives.multiagent.qmixer import QMixerLoss +from utils.logging import init_logging, log_evaluation, log_training + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="qmix_vdn") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=False, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=False, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + net = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + 
centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + module = TensorDictModule( + net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")] + ) + value_module = QValueModule( + action_value_key=("agents", "action_value"), + out_keys=[ + env.action_key, + ("agents", "action_value"), + ("agents", "chosen_action_value"), + ], + spec=env.unbatched_action_spec, + action_space=None, + ) + qnet = SafeSequential(module, value_module) + + qnet_explore = EGreedyWrapper( + qnet, + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ) + + if cfg.loss.mixer_type == "qmix": + mixer = TensorDictModule( + module=QMixer( + state_shape=env.unbatched_observation_spec[ + "agents", "observation" + ].shape, + mixing_embed_dim=32, + n_agents=env.n_agents, + device=cfg.train.device, + ), + in_keys=[("agents", "chosen_action_value"), ("agents", "observation")], + out_keys=["chosen_action_value"], + ) + elif cfg.loss.mixer_type == "vdn": + mixer = TensorDictModule( + module=VDNMixer( + n_agents=env.n_agents, + device=cfg.train.device, + ), + in_keys=[("agents", "chosen_action_value")], + out_keys=["chosen_action_value"], + ) + else: + raise ValueError("Mixer type not in the example") + + collector = SyncDataCollector( + env, + qnet_explore, + device=cfg.env.device, + storing_device=cfg.train.device, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + loss_module = QMixerLoss(qnet, mixer, delay_value=True) + loss_module.set_keys( + action_value=("agents", "action_value"), + local_value=("agents", "chosen_action_value"), + global_value="chosen_action_value", + action=env.action_key, + ) + loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) + target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) + + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ( + "Het" if not cfg.model.shared_parameters else "" + ) + cfg.loss.mixer_type.upper() + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + # Remove agent dimension from reward (since it is shared in QMIX/VDN) + tensordict_data.set( + ("next", "reward"), tensordict_data.get(("next", env.reward_key)).mean(-2) + ) + del tensordict_data["next", env.reward_key] + tensordict_data.set( + ("next", "episode_reward"), + tensordict_data.get(("next", "agents", "episode_reward")).mean(-2), + ) + del tensordict_data["next", "agents", "episode_reward"] + + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + training_tds.append(loss_vals.detach()) + + loss_value = loss_vals["loss"] + + 
loss_value.backward() + + total_norm = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), cfg.train.max_grad_norm + ) + training_tds[-1].set("grad_norm", total_norm.mean()) + + optim.step() + optim.zero_grad() + target_net_updater.step() + + qnet_explore.step(frames=current_frames) # Update exploration annealing + collector.update_policy_weights_() + + training_time = time.time() - training_start + + iteration_time = sampling_time + training_time + total_time += iteration_time + training_tds = torch.stack(training_tds) + + # More logs + if cfg.logger.backend: + log_training( + logger, + training_tds, + tensordict_data, + sampling_time, + training_time, + total_time, + i, + current_frames, + total_frames, + step=i, + ) + + if ( + cfg.eval.evaluation_episodes > 0 + and i % cfg.eval.evaluation_interval == 0 + and cfg.logger.backend + ): + evaluation_start = time.time() + with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + env_test.frames = [] + rollouts = env_test.rollout( + max_steps=cfg.env.max_steps, + policy=qnet, + callback=rendering_callback, + auto_cast_to_device=True, + break_when_any_done=False, + # We are running vectorized evaluation we do not want it to stop when just one env is done + ) + + evaluation_time = time.time() - evaluation_start + + log_evaluation(logger, rollouts, env_test, evaluation_time, step=i) + + if cfg.logger.backend == "wandb": + logger.experiment.log({}, commit=True) + sampling_start = time.time() + + +if __name__ == "__main__": + train() diff --git a/examples/multiagent/qmix_vdn.yaml b/examples/multiagent/qmix_vdn.yaml new file mode 100644 index 00000000000..a78b3987ffb --- /dev/null +++ b/examples/multiagent/qmix_vdn.yaml @@ -0,0 +1,39 @@ +seed: 0 + +env: + max_steps: 100 + scenario_name: "balance" + scenario: + n_agents: 3 + device: ??? # These values will be populated dynamically + vmas_envs: ??? + +model: + shared_parameters: True + +collector: + frames_per_batch: 60_000 # Frames sampled each sampling iteration + n_iters: 500 # Number of sampling/training iterations + total_frames: ??? + +buffer: + memory_size: ??? + +loss: + mixer_type: "qmix" # or "vdn" + gamma: 0.9 + tau: 0.005 # For target net + +train: + num_epochs: 45 # optimization steps per batch of data collected + minibatch_size: 4096 # size of minibatches used in each epoch + lr: 5e-5 + max_grad_norm: 40.0 + device: ??? + +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py new file mode 100644 index 00000000000..fb184291c90 --- /dev/null +++ b/examples/multiagent/sac.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
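
Regarding the `mixer_type` switch in the config above: both mixers reduce the per-agent `("agents", "chosen_action_value")` to the single `chosen_action_value` that `QMixerLoss` trains on. The sketch below is only a schematic of that reduction, not the `VDNMixer`/`QMixer` implementation (which handles batching and produces the QMIX weights with hypernetworks).

```python
# Schematic of the two mixing strategies on toy per-agent Q-values.
import torch

n_agents = 3
chosen_q = torch.randn(n_agents, 1)      # per-agent ("agents", "chosen_action_value")

# VDN: the joint value is the plain sum of the agents' chosen Q-values.
q_tot_vdn = chosen_q.sum(dim=0)

# QMIX: agents' values are mixed with non-negative, state-conditioned weights
# (monotonic in each agent's value, so the greedy joint action stays consistent).
state = torch.randn(18)                  # e.g. concatenated agent observations
w = torch.rand(n_agents, 1)              # stand-in for hypernetwork output, >= 0
b = state.mean()                         # stand-in for the state-dependent bias
q_tot_qmix = (w * chosen_q).sum(dim=0) + b
```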
+ +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn +from torch.distributions import Categorical, OneHotCategorical +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.modules.models.multiagent import MultiAgentMLP +from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators +from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="sac") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=cfg.env.continuous_actions, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + categorical_actions=cfg.env.categorical_actions, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=cfg.env.continuous_actions, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + if cfg.env.continuous_actions: + actor_net = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=2 * env.action_spec.shape[-1], + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + NormalParamExtractor(), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], + ) + + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "loc"), ("agents", "scale")], + out_keys=[env.action_key], + distribution_class=TanhNormal, + distribution_kwargs={ + "min": env.unbatched_action_spec[("agents", "action")].space.low, + "max": env.unbatched_action_spec[("agents", "action")].space.high, + }, + return_log_prob=True, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1] + + env.action_spec.shape[-1], # Q critic takes action and value + n_agent_outputs=1, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + 
share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation"), env.action_key], + out_keys=[("agents", "state_action_value")], + ) + else: + actor_net = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "logits")], + ) + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "logits")], + out_keys=[env.action_key], + distribution_class=OneHotCategorical + if not cfg.env.categorical_actions + else Categorical, + return_log_prob=True, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation")], + out_keys=[("agents", "action_value")], + ) + + collector = SyncDataCollector( + env, + policy, + device=cfg.env.device, + storing_device=cfg.train.device, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + if cfg.env.continuous_actions: + loss_module = SACLoss( + actor_network=policy, + qvalue_network=value_module, + delay_qvalue=True, + action_spec=env.unbatched_action_spec, + ) + loss_module.set_keys( + state_action_value=("agents", "state_action_value"), + action=env.action_key, + reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) + else: + loss_module = DiscreteSACLoss( + actor_network=policy, + qvalue_network=value_module, + delay_qvalue=True, + num_actions=env.action_spec.space.n, + action_space=env.unbatched_action_spec, + ) + loss_module.set_keys( + action_value=("agents", "action_value"), + action=env.action_key, + reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) + + loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) + target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) + + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ( + ("Het" if not cfg.model.shared_parameters else "") + + ("MA" if cfg.model.centralised_critic else "I") + + "SAC" + ) + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + 
replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + training_tds.append(loss_vals.detach()) + + loss_value = ( + loss_vals["loss_actor"] + + loss_vals["loss_alpha"] + + loss_vals["loss_qvalue"] + ) + + loss_value.backward() + + total_norm = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), cfg.train.max_grad_norm + ) + training_tds[-1].set("grad_norm", total_norm.mean()) + + optim.step() + optim.zero_grad() + target_net_updater.step() + + collector.update_policy_weights_() + + training_time = time.time() - training_start + + iteration_time = sampling_time + training_time + total_time += iteration_time + training_tds = torch.stack(training_tds) + + # More logs + if cfg.logger.backend: + log_training( + logger, + training_tds, + tensordict_data, + sampling_time, + training_time, + total_time, + i, + current_frames, + total_frames, + step=i, + ) + + if ( + cfg.eval.evaluation_episodes > 0 + and i % cfg.eval.evaluation_interval == 0 + and cfg.logger.backend + ): + evaluation_start = time.time() + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + env_test.frames = [] + rollouts = env_test.rollout( + max_steps=cfg.env.max_steps, + policy=policy, + callback=rendering_callback, + auto_cast_to_device=True, + break_when_any_done=False, + # We are running vectorized evaluation; we do not want it to stop when just one env is done + ) + + evaluation_time = time.time() - evaluation_start + + log_evaluation(logger, rollouts, env_test, evaluation_time, step=i) + + if cfg.logger.backend == "wandb": + logger.experiment.log({}, commit=True) + sampling_start = time.time() + + +if __name__ == "__main__": + train() diff --git a/examples/multiagent/sac.yaml b/examples/multiagent/sac.yaml new file mode 100644 index 00000000000..ab478ab0dc8 --- /dev/null +++ b/examples/multiagent/sac.yaml @@ -0,0 +1,41 @@ +seed: 0 + +env: + continuous_actions: True # False for discrete sac + categorical_actions: False + max_steps: 100 + scenario_name: "balance" + scenario: + n_agents: 3 + device: ??? # These values will be populated dynamically + vmas_envs: ??? + +model: + shared_parameters: True + centralised_critic: True + +collector: + frames_per_batch: 60_000 # Frames sampled each sampling iteration + n_iters: 500 # Number of sampling/training iterations + total_frames: ??? + +buffer: + memory_size: ??? + +loss: + gamma: 0.9 + tau: 0.005 # For target net + +train: + num_epochs: 45 # optimization steps per batch of data collected + minibatch_size: 4096 # size of minibatches used in each epoch + lr: 5e-5 + max_grad_norm: 2.0 + device: ??? + +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/examples/multiagent/utils/__init__.py b/examples/multiagent/utils/__init__.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/examples/multiagent/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
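
The target-network update in the SAC script above is driven by `SoftUpdate(loss_module, eps=1 - cfg.loss.tau)` and advanced once per optimizer step via `target_net_updater.step()`. TorchRL's `eps` is the fraction of the old target weights kept at each step, i.e. the complement of the Polyak coefficient `tau` set in `sac.yaml`. Below is a minimal stand-alone sketch of the equivalent manual update; the `online`/`target` modules are illustrative placeholders, not part of the example.

```python
import torch
from torch import nn

tau = 0.005        # loss.tau in sac.yaml
eps = 1.0 - tau    # value passed to SoftUpdate: fraction of the old target weights to keep

online = nn.Linear(8, 4)
target = nn.Linear(8, 4)
target.load_state_dict(online.state_dict())


@torch.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, eps: float) -> None:
    # target <- eps * target + (1 - eps) * online
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(eps).add_(p_online, alpha=1.0 - eps)


# one soft update per optimizer step, mirroring target_net_updater.step() in the training loop
polyak_update(online, target, eps)
```

With `tau=0.005`, the target parameters are an exponential moving average of the online ones with a time constant of roughly `1/tau = 200` updates.
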
diff --git a/examples/multiagent/utils/logging.py b/examples/multiagent/utils/logging.py new file mode 100644 index 00000000000..352d0addc51 --- /dev/null +++ b/examples/multiagent/utils/logging.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os + +import numpy as np +import torch +from tensordict import TensorDictBase +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.record.loggers import generate_exp_name, get_logger, Logger +from torchrl.record.loggers.wandb import WandbLogger + + +def init_logging(cfg, model_name: str): + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name=os.getcwd(), + experiment_name=generate_exp_name(cfg.env.scenario_name, model_name), + wandb_kwargs={ + "group": model_name, + "project": f"torchrl_{cfg.env.scenario_name}", + }, + ) + logger.log_hparams(cfg) + return logger + + +def log_training( + logger: Logger, + training_td: TensorDictBase, + sampling_td: TensorDictBase, + sampling_time: float, + training_time: float, + total_time: float, + iteration: int, + current_frames: int, + total_frames: int, + step: int, +): + if ("next", "agents", "reward") not in sampling_td.keys(True, True): + sampling_td.set( + ("next", "agents", "reward"), + sampling_td.get(("next", "reward")) + .expand(sampling_td.get("agents").shape) + .unsqueeze(-1), + ) + if ("next", "agents", "episode_reward") not in sampling_td.keys(True, True): + sampling_td.set( + ("next", "agents", "episode_reward"), + sampling_td.get(("next", "episode_reward")) + .expand(sampling_td.get("agents").shape) + .unsqueeze(-1), + ) + + to_log = { + f"train/learner/{key}": value.mean().item() + for key, value in training_td.items() + } + + if "info" in sampling_td.get("agents").keys(): + to_log.update( + { + f"train/info/{key}": value.mean().item() + for key, value in sampling_td.get(("agents", "info")).items() + } + ) + + reward = sampling_td.get(("next", "agents", "reward")).mean(-2) # Mean over agents + done = sampling_td.get(("next", "done")) + if done.ndim > reward.ndim: + done = done[..., 0, :] # Remove expanded agent dim + episode_reward = sampling_td.get(("next", "agents", "episode_reward")).mean(-2)[ + done + ] + to_log.update( + { + "train/reward/reward_min": reward.min().item(), + "train/reward/reward_mean": reward.mean().item(), + "train/reward/reward_max": reward.max().item(), + "train/reward/episode_reward_min": episode_reward.min().item(), + "train/reward/episode_reward_mean": episode_reward.mean().item(), + "train/reward/episode_reward_max": episode_reward.max().item(), + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/iteration_time": training_time + sampling_time, + "train/total_time": total_time, + "train/training_iteration": iteration, + "train/current_frames": current_frames, + "train/total_frames": total_frames, + } + ) + if isinstance(logger, WandbLogger): + logger.experiment.log(to_log, commit=False) + else: + for key, value in to_log.items(): + logger.log_scalar(key.replace("/", "_"), value, step=step) + + return to_log + + +def log_evaluation( + logger: WandbLogger, + rollouts: TensorDictBase, + env_test: VmasEnv, + evaluation_time: float, + step: int, +): + rollouts = list(rollouts.unbind(0)) + for k, r in enumerate(rollouts): + next_done = r.get(("next", "done")).sum( + tuple(range(r.batch_dims, r.get(("next", "done")).ndim)), + dtype=torch.bool, + ) + done_index 
= next_done.nonzero(as_tuple=True)[0][ + 0 + ] # First done index for this traj + rollouts[k] = r[: done_index + 1] + + rewards = [td.get(("next", "agents", "reward")).sum(0).mean() for td in rollouts] + to_log = { + "eval/episode_reward_min": min(rewards), + "eval/episode_reward_max": max(rewards), + "eval/episode_reward_mean": sum(rewards) / len(rollouts), + "eval/episode_len_mean": sum([td.batch_size[0] for td in rollouts]) + / len(rollouts), + "eval/evaluation_time": evaluation_time, + } + + vid = torch.tensor( + np.transpose(env_test.frames[: rollouts[0].batch_size[0]], (0, 3, 1, 2)), + dtype=torch.uint8, + ).unsqueeze(0) + + if isinstance(logger, WandbLogger): + import wandb + + logger.experiment.log(to_log, commit=False) + logger.experiment.log( + { + "eval/video": wandb.Video(vid, fps=1 / env_test.world.dt, format="mp4"), + }, + commit=False, + ) + else: + for key, value in to_log.items(): + logger.log_scalar(key.replace("/", "_"), value, step=step) + logger.log_video("eval_video", vid, step=step) diff --git a/examples/multiagent/utils/utils.py b/examples/multiagent/utils/utils.py new file mode 100644 index 00000000000..d21bafdf691 --- /dev/null +++ b/examples/multiagent/utils/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from tensordict import unravel_key +from torchrl.envs import Transform + + +def swap_last(source, dest): + source = unravel_key(source) + dest = unravel_key(dest) + if isinstance(source, str): + if isinstance(dest, str): + return dest + return dest[-1] + if isinstance(dest, str): + return source[:-1] + (dest,) + return source[:-1] + (dest[-1],) + + +class DoneTransform(Transform): + """Expands the 'done' entries (incl. terminated) to match the reward shape. + + Can be appended to a replay buffer or a collector. + """ + + def __init__(self, reward_key, done_keys): + super().__init__() + self.reward_key = reward_key + self.done_keys = done_keys + + def forward(self, tensordict): + for done_key in self.done_keys: + new_name = swap_last(self.reward_key, done_key) + tensordict.set( + ("next", new_name), + tensordict.get(("next", done_key)) + .unsqueeze(-1) + .expand(tensordict.get(("next", self.reward_key)).shape), + ) + return tensordict diff --git a/examples/ppo/README.md b/examples/ppo/README.md new file mode 100644 index 00000000000..7d27f746e2a --- /dev/null +++ b/examples/ppo/README.md @@ -0,0 +1,29 @@ +## Reproducing Proximal Policy Optimization (PPO) Algorithm Results + +This repository contains scripts that enable training agents using the Proximal Policy Optimization (PPO) Algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017) to implement the PPO algorithm but introduce the improvement of computing the Generalised Advantage Estimator (GAE) at every epoch. + + +## Examples Structure + +Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: + +1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. ppo_atari.py). + +2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py). + +3. 
**Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml). + + +## Running the Examples + +You can execute the PPO algorithm on Atari environments by running the following command: + +```bash +python ppo_atari.py +``` + +You can execute the PPO algorithm on MuJoCo environments by running the following command: + +```bash +python ppo_mujoco.py +``` diff --git a/examples/ppo/config.yaml b/examples/ppo/config.yaml deleted file mode 100644 index d7840906c92..00000000000 --- a/examples/ppo/config.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# task and env -defaults: - - hydra/job_logging: disabled - -env: - env_name: PongNoFrameskip-v4 - env_task: "" - env_library: gym - frame_skip: 4 - num_envs: 8 - noop: 1 - reward_scaling: 1.0 - from_pixels: True - n_samples_stats: 1000 - device: cuda:0 - -# collector -collector: - frames_per_batch: 4096 - total_frames: 40_000_000 - collector_device: cuda:0 # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: ppo_pong_gym - log_interval: 10000 - -# Optim -optim: - device: cuda:0 - lr: 2.5e-4 - weight_decay: 0.0 - lr_scheduler: True - -# loss -loss: - gamma: 0.99 - mini_batch_size: 1024 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.1 - critic_coef: 0.5 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/ppo/config_atari.yaml b/examples/ppo/config_atari.yaml new file mode 100644 index 00000000000..6957fd9bddd --- /dev/null +++ b/examples/ppo/config_atari.yaml @@ -0,0 +1,36 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 8 + +# collector +collector: + frames_per_batch: 4096 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 2.5e-4 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 1024 + ppo_epochs: 3 + gae_lambda: 0.95 + clip_epsilon: 0.1 + anneal_clip_epsilon: True + critic_coef: 1.0 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/ppo/config_example2.yaml b/examples/ppo/config_example2.yaml deleted file mode 100644 index 465e6727119..00000000000 --- a/examples/ppo/config_example2.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# task and env -env: - env_name: HalfCheetah-v4 - env_task: "" - env_library: gym - frame_skip: 1 - num_envs: 1 - noop: 1 - reward_scaling: 1.0 - from_pixels: False - n_samples_stats: 3 - -# collector -collector: - frames_per_batch: 2048 - total_frames: 1_000_000 - collector_device: cuda # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: ppo_halfcheetah_gym - log_interval: 10000 - -# Optim -optim: - device: cuda - lr: 3e-4 - weight_decay: 0.0 - lr_scheduler: False - -# loss -loss: - gamma: 0.99 - mini_batch_size: 64 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.2 - critic_coef: 0.5 - entropy_coef: 0.0 - loss_critic_type: l2 diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml new file mode 100644 index 00000000000..1272c1f4bff --- /dev/null +++ b/examples/ppo/config_mujoco.yaml @@ -0,0 +1,33 @@ +# task and env +env: + env_name: HalfCheetah-v3 + +# collector +collector: + frames_per_batch: 2048 + total_frames: 1_000_000 + +# logger +logger: + backend: wandb + exp_name: Mujoco_Schulman17 + test_interval: 1_000_000 + num_test_episodes: 5 + +# Optim +optim: + lr: 3e-4 + weight_decay: 0.0 
+ anneal_lr: False + +# loss +loss: + gamma: 0.99 + mini_batch_size: 64 + ppo_epochs: 10 + gae_lambda: 0.95 + clip_epsilon: 0.2 + anneal_clip_epsilon: False + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py deleted file mode 100644 index 7f532bc0c4d..00000000000 --- a/examples/ppo/ppo.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""PPO Example. - -This is a self-contained example of a PPO training script. - -Both state and pixel-based environments are supported. - -The helper functions are coded in the utils.py associated with this script. -""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import torch - import tqdm - from tensordict import TensorDict - from torchrl.envs.utils import ExplorationType, set_exploration_type - from utils import ( - make_collector, - make_data_buffer, - make_logger, - make_loss, - make_optim, - make_ppo_models, - make_test_env, - ) - - # Correct for frame_skip - cfg.collector.total_frames = cfg.collector.total_frames // cfg.env.frame_skip - cfg.collector.frames_per_batch = ( - cfg.collector.frames_per_batch // cfg.env.frame_skip - ) - mini_batch_size = cfg.loss.mini_batch_size = ( - cfg.loss.mini_batch_size // cfg.env.frame_skip - ) - - model_device = cfg.optim.device - actor, critic, critic_head = make_ppo_models(cfg) - - collector, state_dict = make_collector(cfg, policy=actor) - data_buffer = make_data_buffer(cfg) - loss_module, adv_module = make_loss( - cfg.loss, - actor_network=actor, - value_network=critic, - value_head=critic_head, - ) - optim = make_optim(cfg.optim, loss_module) - - batch_size = cfg.collector.total_frames * cfg.env.num_envs - num_mini_batches = batch_size // mini_batch_size - total_network_updates = ( - (cfg.collector.total_frames // batch_size) - * cfg.loss.ppo_epochs - * num_mini_batches - ) - - scheduler = None - if cfg.optim.lr_scheduler: - scheduler = torch.optim.lr_scheduler.LinearLR( - optim, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1 - ) - - logger = None - if cfg.logger.backend: - logger = make_logger(cfg.logger) - test_env = make_test_env(cfg.env, state_dict) - record_interval = cfg.logger.log_interval - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - collected_frames = 0 - - # Main loop - r0 = None - l0 = None - frame_skip = cfg.env.frame_skip - ppo_epochs = cfg.loss.ppo_epochs - total_done = 0 - for data in collector: - - frames_in_batch = data.numel() - total_done += data.get(("next", "done")).sum() - collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) - - # Log end-of-episode accumulated rewards for training - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if logger is not None and len(episode_rewards) > 0: - logger.log_scalar( - "reward_training", episode_rewards.mean().item(), collected_frames - ) - - losses = TensorDict( - {}, batch_size=[ppo_epochs, -(frames_in_batch // -mini_batch_size)] - ) - for j in range(ppo_epochs): - # Compute GAE - with torch.no_grad(): - data = adv_module(data.to(model_device)).cpu() - - data_reshape = data.reshape(-1) - # Update the data buffer - data_buffer.extend(data_reshape) - - for i, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(model_device) - - # 
Forward pass PPO loss - loss = loss_module(batch) - losses[j, i] = loss.detach() - - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=0.5 - ) - losses[j, i]["grad_norm"] = grad_norm - - optim.step() - if scheduler is not None: - scheduler.step() - optim.zero_grad() - - # Logging - if r0 is None: - r0 = data["next", "reward"].mean().item() - if l0 is None: - l0 = loss_sum.item() - pbar.set_description( - f"loss: {loss_sum.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['next', 'reward'].mean(): 4.4f} (init={r0: 4.4f})" - ) - if i + 1 != -(frames_in_batch // -mini_batch_size): - print( - f"Should have had {- (frames_in_batch // -mini_batch_size)} iters but had {i}." - ) - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) - if logger is not None: - for key, value in losses.items(): - logger.log_scalar(key, value.item(), collected_frames) - logger.log_scalar("total_done", total_done, collected_frames) - - collector.update_policy_weights_() - - # Test current policy - if ( - logger is not None - and (collected_frames - frames_in_batch) // record_interval - < collected_frames // record_interval - ): - - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - test_env.eval() - actor.eval() - # Generate a complete episode - td_test = test_env.rollout( - policy=actor, - max_steps=10_000_000, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - ).clone() - logger.log_scalar( - "reward_testing", - td_test["next", "reward"].sum().item(), - collected_frames, - ) - actor.train() - del td_test - - -if __name__ == "__main__": - main() diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py new file mode 100644 index 00000000000..eb2ce15ec5a --- /dev/null +++ b/examples/ppo/ppo_atari.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 on Atari environments.
+""" + +import hydra + + +@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_atari import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # use end-of-life as done key + adv_module.set_keys(done="end-of-life", terminated="end-of-life") + loss_module.set_keys(done="end-of-life", terminated="end-of-life") + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + sampling_start = time.time() + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.optim.anneal_lr + cfg_optim_lr = cfg.optim.lr + cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.loss.clip_epsilon + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + cfg_optim_max_grad_norm = cfg.optim.max_grad_norm + cfg.loss.clip_epsilon = 
cfg_loss_clip_epsilon + losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) + + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + training_start = time.time() + for j in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses_mean.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg_optim_lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg_loss_clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg_logger_num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_atari_pong.png b/examples/ppo/ppo_atari_pong.png deleted file mode 100644 index 639545f29e4..00000000000 Binary files a/examples/ppo/ppo_atari_pong.png and /dev/null differ diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py new file mode 100644 index 00000000000..37230fb33c6 --- /dev/null +++ b/examples/ppo/ppo_mujoco.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc.
and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 on MuJoCo environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +def main(cfg: "DictConfig"):  # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_mujoco import eval_model, make_env, make_ppo_models + + # Define paper hyperparameters + device = "cpu" if not torch.cuda.device_count() else "cuda" + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size + total_network_updates = ( + (cfg.collector.total_frames // cfg.collector.frames_per_batch) + * cfg.loss.ppo_epochs + * num_mini_batches + ) + + # Create models (check utils_mujoco.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device), + sampler=sampler, + batch_size=cfg.loss.mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + sampling_start = time.time() + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.optim.anneal_lr + cfg_optim_lr = cfg.optim.lr + cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) +
+ for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + training_start = time.time() + for j in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + for group in critic_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses_mean.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg_optim_lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg_loss_clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( + i * frames_in_batch + ) // cfg_logger_test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg_logger_num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_mujoco_halfcheetah.png b/examples/ppo/ppo_mujoco_halfcheetah.png deleted file mode 100644 index f168a5d40f3..00000000000 Binary files a/examples/ppo/ppo_mujoco_halfcheetah.png and /dev/null differ diff --git a/examples/ppo/training_curves.md b/examples/ppo/training_curves.md deleted file mode 100644 index d9f99eadb42..00000000000 --- a/examples/ppo/training_curves.md +++ /dev/null @@ -1,13 +0,0 @@ -# PPO Example Results - -## Atari Pong Environment - -We tested the Proximal Policy 
Optimization (PPO) algorithm on the Atari Pong environment. The hyperparameters used for the training are specified in the config.yaml file and are the same as those used in the original PPO paper (https://arxiv.org/abs/1707.06347). - -![ppo_atari_pong.png](ppo_atari_pong.png) - -## MuJoCo HalfCheetah Environment - -Additionally, we also tested the PPO algorithm on the MuJoCo HalfCheetah environment. The hyperparameters used for the training are specified in the config_example2.yaml file and are also the same as those used in the original PPO paper. However, this implementation uses a shared policy-value architecture. - -![ppo_mujoco_halfcheetah.png](ppo_mujoco_halfcheetah.png) diff --git a/examples/ppo/utils.py b/examples/ppo/utils.py deleted file mode 100644 index 47fcb992b36..00000000000 --- a/examples/ppo/utils.py +++ /dev/null @@ -1,497 +0,0 @@ -import torch.nn -import torch.optim -from tensordict.nn import TensorDictModule - -from torchrl.collectors import SyncDataCollector -from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement - -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - CatTensors, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ObservationNorm, - ParallelEnv, - Resize, - RewardScaling, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - NormalParamWrapper, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) -from torchrl.objectives import ClipPPOLoss -from torchrl.objectives.value.advantages import GAE -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.envs import LIBS - - -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - - -# ==================================================================== -# Environment utils -# ----------------- - - -def make_base_env(env_cfg, from_pixels=None): - env_library = LIBS[env_cfg.env_library] - env_kwargs = { - "env_name": env_cfg.env_name, - "frame_skip": env_cfg.frame_skip, - "from_pixels": env_cfg.from_pixels - if from_pixels is None - else from_pixels, # for rendering - "pixels_only": False, - "device": env_cfg.device, - } - if env_library is DMControlEnv: - env_task = env_cfg.env_task - env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) - return env - - -def make_transformed_env(base_env, env_cfg): - if env_cfg.noop > 1: - base_env = TransformedEnv(env=base_env, transform=NoopResetEnv(env_cfg.noop)) - from_pixels = env_cfg.from_pixels - if from_pixels: - return make_transformed_env_pixels(base_env, env_cfg) - else: - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_pixels(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env_library = LIBS[env_cfg.env_library] - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - env.append_transform(RewardScaling(0.0, reward_scaling)) - - double_to_float_list = [] - double_to_float_inv_list = [] - - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - 
env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - - obs_norm = ObservationNorm(in_keys=["pixels"], standard_normal=True) - env.append_transform(obs_norm) - - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision - - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) - return env - - -def make_transformed_env_states(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env_library = LIBS[env_cfg.env_library] - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - - env.append_transform(RewardScaling(0.0, reward_scaling)) - - double_to_float_list = [] - double_to_float_inv_list = [] - - # we concatenate all the state vectors - # even if there is a single tensor, it'll be renamed in "observation_vector" - selected_keys = [ - key for key in env.observation_spec.keys(True, True) if key != "pixels" - ] - out_key = "observation_vector" - env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - # obs_norm = ObservationNorm(in_keys=[out_key]) - # env.append_transform(obs_norm) - - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision - double_to_float_list += ["observation_vector"] - else: - double_to_float_list += ["observation_vector"] - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) - return env - - -def make_parallel_env(env_cfg, state_dict): - num_envs = env_cfg.num_envs - env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg - ) - init_stats(env, 3, env_cfg.from_pixels) - env.load_state_dict(state_dict, strict=False) - return env - - -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(env, env_cfg.n_samples_stats, env_cfg.from_pixels) - state_dict = env.state_dict() - for key in list(state_dict.keys()): - if key.endswith("loc") or key.endswith("scale"): - continue - del state_dict[key] - return state_dict - - -def init_stats(env, n_samples_stats, from_pixels): - for t in env.transform: - if isinstance(t, ObservationNorm): - if from_pixels: - t.init_stats( - n_samples_stats, - cat_dim=-4, - reduce_dim=tuple( - -i for i in range(1, len(t.parent.batch_size) + 5) - ), - keep_dims=(-1, -2, -3), - ) - else: - t.init_stats(n_samples_stats) - - -def make_test_env(env_cfg, state_dict): - env_cfg.num_envs = 1 - env = make_parallel_env(env_cfg, state_dict=state_dict) - return env - - -# ==================================================================== -# Collector and replay buffer -# --------------------------- - - -def make_collector(cfg, policy): - env_cfg = cfg.env - collector_cfg = cfg.collector - collector_class = SyncDataCollector - state_dict = get_stats(env_cfg) - collector = collector_class( - make_parallel_env(env_cfg, state_dict=state_dict), - policy, - frames_per_batch=collector_cfg.frames_per_batch, - total_frames=collector_cfg.total_frames, - device=collector_cfg.collector_device, - storing_device="cpu", - max_frames_per_traj=collector_cfg.max_frames_per_traj, - ) - return 
collector, state_dict - - -def make_data_buffer(cfg): - cfg_collector = cfg.collector - cfg_loss = cfg.loss - sampler = SamplerWithoutReplacement() - return TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg_collector.frames_per_batch), - sampler=sampler, - batch_size=cfg_loss.mini_batch_size, - ) - - -# ==================================================================== -# Model -# ----- -# -# We give one version of the model for learning from pixels, and one for state. -# TorchRL comes in handy at this point, as the high-level interactions with -# these models is unchanged, regardless of the modality. - - -def make_ppo_models(cfg): - - env_cfg = cfg.env - from_pixels = env_cfg.from_pixels - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(proof_environment, 3, env_cfg.from_pixels) - - if not from_pixels: - # we must initialize the observation norm transform - # init_stats( - # proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels - # ) - common_module, policy_module, value_module = make_ppo_modules_state( - proof_environment - ) - else: - common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ).to(cfg.optim.device) - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() - - return actor, critic, critic_head - - -def make_ppo_modules_state(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["observation_vector"].shape - - # Define distribution class and kwargs - continuous_actions = False - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - continuous_actions = True - num_outputs = proof_environment.action_spec.shape[-1] * 2 - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - "tanh_loc": False, - } - - # Define input keys - in_keys = ["observation_vector"] - shared_features_size = 256 - - # Define a shared Module and TensorDictModule - common_mlp = MLP( - in_features=input_shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=shared_features_size, - num_cells=[64, 64], - ) - common_module = TensorDictModule( - module=common_mlp, - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=shared_features_size, out_features=num_outputs, num_cells=[] - ) - if continuous_actions: - policy_net = NormalParamWrapper(policy_net) - - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["loc", "scale"] if continuous_actions else ["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["loc", "scale"] if continuous_actions else ["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - safe=True, - distribution_class=distribution_class, - 
distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP(in_features=shared_features_size, out_features=1, num_cells=[]) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -def make_ppo_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn = ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - activation_class=torch.nn.ReLU, - num_cells=[256], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - # safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - activation_class=torch.nn.ReLU, - in_features=common_mlp_output.shape[-1], - out_features=1, - num_cells=[256], - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -# ==================================================================== -# PPO Loss -# --------- - - -def make_advantage_module(loss_cfg, value_network): - advantage_module = GAE( - gamma=loss_cfg.gamma, - lmbda=loss_cfg.gae_lamdda, - value_network=value_network, - average_gae=True, - ) - return advantage_module - - -def make_loss(loss_cfg, actor_network, value_network, value_head): - advantage_module = make_advantage_module(loss_cfg, value_network) - loss_module = ClipPPOLoss( - actor=actor_network, - critic=value_head, - clip_epsilon=loss_cfg.clip_epsilon, - loss_critic_type=loss_cfg.loss_critic_type, - entropy_coef=loss_cfg.entropy_coef, - critic_coef=loss_cfg.critic_coef, - normalize_advantage=True, - ) - return loss_module, advantage_module - - -def make_optim(optim_cfg, loss_module): - optim = torch.optim.Adam( - loss_module.parameters(), - 
lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - ) - return optim - - -# ==================================================================== -# Logging and recording -# --------------------- - - -def make_logger(logger_cfg): - exp_name = generate_exp_name("PPO", logger_cfg.exp_name) - logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="ppo", experiment_name=exp_name) - return logger diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py new file mode 100644 index 00000000000..1355212ed70 --- /dev/null +++ b/examples/ppo/utils_atari.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EndOfLifeTransform, + EnvCreator, + ExplorationType, + GrayScale, + GymEnv, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = GymEnv( + env_name, + frame_skip=frame_skip, + from_pixels=True, + pixels_only=False, + device=device, + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + env.append_transform(EndOfLifeTransform()) + return env + + +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.low, + "max": proof_environment.action_spec.space.high, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + 
kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, 1, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + + del proof_environment + + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards.append(reward.cpu()) + del td_test + return torch.cat(test_rewards, 0).mean() diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py new file mode 100644 index 00000000000..8fa2a53fd92 --- /dev/null +++ b/examples/ppo/utils_mujoco.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + StepCounter, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu"): + env = GymEnv(env_name, device=device) + env = TransformedEnv(env) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + env.append_transform(VecNorm(in_keys=["observation"])) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.low, + "max": proof_environment.action_spec.space.high, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + 
policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards.append(reward.cpu()) + del td_test + return torch.cat(test_rewards, 0).mean() diff --git a/examples/rlhf/.gitignore b/examples/rlhf/.gitignore new file mode 100644 index 00000000000..d8bad909a58 --- /dev/null +++ b/examples/rlhf/.gitignore @@ -0,0 +1,4 @@ +*.png +*.bin +*.pt +*.json diff --git a/examples/rlhf/README.md b/examples/rlhf/README.md new file mode 100644 index 00000000000..c4b0a261101 --- /dev/null +++ b/examples/rlhf/README.md @@ -0,0 +1,57 @@ +# RLHF example + +This example uses RLHF (Reinforcement Learning from Human Feedback) to train a +language model to summarize Reddit posts. + +## Getting started + +Make sure you have PyTorch>=2.0 installed. You can find installation instructions +[here](https://pytorch.org/get-started/locally/). + +From this directory, you can install extra requirements for running these +examples with + +```sh +pip install -r requirements.txt +``` + +## Training the models +### Training the transformer + +Once the data has been prepared, you can train the GPT model: + +```sh +python train.py +``` + +The default configuration can be found in `config/train.yaml`, and any option can +be overridden with Hydra-style command-line overrides, for example, to run the training +script with a different batch size: + +```sh +python train.py data.batch_size=128 +``` +> **_NOTE:_** Apple Silicon MacBook users should set `sys.device=mps` +> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback. + +### Training the reward model + +Once you have completed supervised fine-tuning, copy the desired model +checkpoint to `./out` or update the config to point `model.name_or_path` at +the relevant checkpoint in the timestamped working directory created by Hydra. +You can then train the reward model with: + +```sh +python train_reward.py +``` + +### Training the final model with RLHF + +Once again, make sure you have either updated the configuration to point +`reward_model.name_or_path` at the relevant timestamped working directory, or +copied the checkpoint to `./out_reward`.
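Before launching the RLHF stage, it can be worth sanity-checking that the reward checkpoint loads and produces a score. The snippet below is a hypothetical check run from this directory: the `./out_reward` path comes from the config, the prompt text is a placeholder, and it reuses `init_reward_model` from `models/reward.py` further down, which wraps `GPT2RewardModel` with `input_ids`/`attention_mask` in-keys and `rewards`/`end_scores` out-keys.

```python
import torch
from tensordict import TensorDict
from transformers import GPT2Tokenizer

from models.reward import init_reward_model  # defined below in models/reward.py

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the checkpoint trained by train_reward.py (path as configured above)
reward_model = init_reward_model(reward_model_path="./out_reward", device="cpu")

text = "SUBREDDIT: r/test\nPOST: ...\nTL;DR: a placeholder summary"
enc = tokenizer(text, return_tensors="pt")
td = TensorDict(
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
    batch_size=[1],
)
with torch.no_grad():
    td = reward_model(td)
print(td["end_scores"])  # scalar score the RLHF stage will optimize against
```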
+You can then train the final model by running + +```sh +python train_rlhf.py +``` diff --git a/examples/rlhf/config/train.yaml b/examples/rlhf/config/train.yaml new file mode 100644 index 00000000000..6d27088902f --- /dev/null +++ b/examples/rlhf/config/train.yaml @@ -0,0 +1,30 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint + out_dir: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 5000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir + decay_lr: True # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 5000 # maximum number of iterations + eta_min: 1.0e-6 # minimum learning rate +sys: + device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks + dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/config/train_reward.yaml b/examples/rlhf/config/train_reward.yaml new file mode 100644 index 00000000000..a5523b75fe2 --- /dev/null +++ b/examples/rlhf/config/train_reward.yaml @@ -0,0 +1,32 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + out_dir: ./out_reward + init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 20000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each eval + decay_lr: False # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 20000 + eta_min: 1.0e-6 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/config/train_rlhf.yaml b/examples/rlhf/config/train_rlhf.yaml new file mode 100644 index 00000000000..024c239463e --- /dev/null +++ b/examples/rlhf/config/train_rlhf.yaml @@ -0,0 +1,39 @@ +io: + eval_interval: 6 + log_interval: 1 + eval_iters: 10 + logger: wandb +data: + batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 + num_workers: 1 +model: + name_or_path: ./out + out_dir: ./out_rlhf + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + name_or_path: 
./out_reward +train: + grad_clip: 1.0 + max_epochs: 1000 # total number of training iterations + always_save_checkpoint: True # if True, always save a checkpoint after each eval + decay_lr: True + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 5.0e-5 + weight_decay: 0.0 # 01 + betas: [0.9, 0.999] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size + eta_min: 5.0e-6 + ppo: + episode_length: 50 + ppo_batch_size: 16 + ppo_num_epochs: 3 + num_rollouts_per_epoch: 32 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + ref_device: cuda:1 # device of reference model + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: False # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/data/__init__.py b/examples/rlhf/data/__init__.py new file mode 100644 index 00000000000..433c23452f2 --- /dev/null +++ b/examples/rlhf/data/__init__.py @@ -0,0 +1,3 @@ +from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr + +__all__ = ["get_prompt_dataloader_tldr"] diff --git a/examples/rlhf/models/__init__.py b/examples/rlhf/models/__init__.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/examples/rlhf/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py new file mode 100644 index 00000000000..3de34d55166 --- /dev/null +++ b/examples/rlhf/models/actor_critic.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator +from torchrl.modules.tensordict_module.common import VmapModule + +from .transformer import init_transformer + +__all__ = ["init_actor_critic"] + + +def init_actor_critic(model_cfg, sys_cfg): + + transformer_name_or_path = model_cfg.name_or_path + dropout = model_cfg.dropout + + device = sys_cfg.device + compile_model = sys_cfg.compile + base_model = init_transformer( + transformer_name_or_path, + dropout, + device, + as_tensordictmodule=False, + compile_model=compile_model, + inference=True, + ) + model = LMHeadActorValueOperator(base_model) + model.to(device) + model.eval() + actor = model.get_policy_operator() + critic = model.get_value_operator() + critic_head = model.get_value_head() + + return actor, VmapModule(critic), critic_head, base_model diff --git a/examples/rlhf/models/reward.py b/examples/rlhf/models/reward.py new file mode 100644 index 00000000000..da69e74ab4d --- /dev/null +++ b/examples/rlhf/models/reward.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
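For orientation, the PPO sizing implied by the `train_rlhf.yaml` defaults above can be worked out from the formulas that appear further down in `train_rlhf.py` and in `utils.make_replay_buffer`. A small back-of-the-envelope sketch, using the config values as stated (the minibatch count is an upper bound, since early EOS can shorten trajectories):

```python
# Values copied from config/train_rlhf.yaml above
batch_size = 4                # data.batch_size (prompts per rollout batch)
episode_length = 50           # train.ppo.episode_length (max generated tokens)
num_rollouts_per_epoch = 32   # train.ppo.num_rollouts_per_epoch
ppo_batch_size = 16           # train.ppo.ppo_batch_size

# train_rlhf.py: collection_iters = num_rollouts_per_epoch // batch_size
collection_iters = num_rollouts_per_epoch // batch_size   # 8 prompt batches per epoch

# utils.make_replay_buffer: storage capacity and outer sampling batch
rb_capacity = episode_length * num_rollouts_per_epoch     # 1600 transitions stored
rb_batch = episode_length * batch_size                    # 200 transitions per outer batch

# Upper bound on PPO minibatches per outer batch; flatten_td drops the
# post-EOS steps, so the flattened batch can be smaller in practice.
max_minibatches = rb_batch // ppo_batch_size               # 12

print(collection_iters, rb_capacity, rb_batch, max_minibatches)
```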
+import warnings + +import torch +from tensordict.nn import TensorDictModule + +from torchrl.modules.models.rlhf import GPT2RewardModel + + +def init_reward_model( + transformer_path=None, reward_model_path=None, device=None, compile_model=False +): + if transformer_path is None and reward_model_path is None: + warnings.warn( + "You did not provide a path to the reward model, a naive reward model will be used instead." + ) + model = GPT2RewardModel() + else: + if not ((transformer_path is None) ^ (reward_model_path is None)): + raise ValueError( + "Exactly one of transformer_path or reward_model_path should be specified." + ) + if transformer_path is not None: + model = GPT2RewardModel(transformer_path) + else: + model = GPT2RewardModel.from_pretrained(reward_model_path) + + model.to(device) + if compile_model: + print("Compiling the reward model...") + model = torch.compile(model) + + model = TensorDictModule( + model, + in_keys=["input_ids", "attention_mask"], + out_keys=["rewards", "end_scores"], + ) + return model diff --git a/examples/rlhf/models/transformer.py b/examples/rlhf/models/transformer.py new file mode 100644 index 00000000000..a33891a86a5 --- /dev/null +++ b/examples/rlhf/models/transformer.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +from tensordict.nn import TensorDictModule +from transformers import GPT2LMHeadModel + + +def init_transformer( + name_or_path, + dropout, + device, + compile_model, + as_tensordictmodule=True, + inference=False, +): + model_kwargs = { + "resid_pdrop": dropout, + "embd_pdrop": dropout, + "attn_pdrop": dropout, + "summary_first_dropout": dropout, + } + model = GPT2LMHeadModel.from_pretrained( + name_or_path, return_dict=False, **model_kwargs + ) + model.to(device) + + if compile_model: + # TODO: logging instead of printing? + print("Compiling transformer model...") + model = torch.compile(model) + + if as_tensordictmodule: + model = TensorDictModule( + model, + in_keys={ + "input_ids": "input_ids", + "attention_mask": "attention_mask", + "labels": "labels", + }, + out_keys=["logits"] if inference else ["loss", "logits"], + ) + return model diff --git a/examples/rlhf/requirements.txt b/examples/rlhf/requirements.txt new file mode 100644 index 00000000000..9bff1b48453 --- /dev/null +++ b/examples/rlhf/requirements.txt @@ -0,0 +1,11 @@ +datasets +hydra-core +matplotlib +numpy +PyYAML +requests +tiktoken +tqdm +transformers +git+https://github.com/pytorch/rl +git+https://github.com/pytorch-labs/tensordict diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py new file mode 100644 index 00000000000..2e554f3edb9 --- /dev/null +++ b/examples/rlhf/train.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train the transformer model. Configurable via config/train.yaml, but any argument can +also be overridden at the command line. 
+ +To run on a single GPU, example: +$ python train.py --batch_size=32 --compile=False +""" +import time + +import hydra +import torch +from models.transformer import init_transformer +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from utils import get_file_logger, resolve_name_or_path, setup + + +def create_loss_estimator(eval_iters, ctx): + # helps estimate an arbitrarily accurate loss over either split using many batches + + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + batch.batch_size = [] + with ctx: + model(batch) + losses[k] = batch.loss.item() + model.train() + return losses.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "transformer_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + out_dir = model_cfg.out_dir + + grad_clip = train_cfg.grad_clip + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + gradient_accumulation_steps = train_cfg.gradient_accumulation_steps + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + ) + + model = init_transformer( + resolve_name_or_path(model_cfg.name_or_path), + model_cfg.dropout, + device, + compile_model=compile_, + ) + optimizer = torch.optim.AdamW(model.parameters(), **train_cfg.optimizer) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) + estimate_loss = create_loss_estimator(eval_iters, ctx) + + best_val_loss = float("inf") + + t0 = time.time() + next_batch = next(train_loader) # fetch the very first batch + for it in range(1, max_iters + 1): + for _ in range(gradient_accumulation_steps): + batch = next_batch + # TODO: can we handle this better with a differently structured tensorclass? 
+ batch.batch_size = [] + with ctx: + model(batch) + # immediately async prefetch next batch while model is doing the forward pass on the GPU + next_batch = next(train_loader) + # backward pass, with gradient scaling if training in fp16 + scaler.scale(batch.loss).backward() + + # clip the gradient + if grad_clip != 0.0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + # step the optimizer and scaler if training in fp16 + scaler.step(optimizer) + scaler.update() + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) + + # update learning rate + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + if it % eval_interval == 0: + # evaluate the loss on train/val sets and write checkpoints + train_loss = estimate_loss(model, train_loader) + val_loss = estimate_loss(model, val_loader) + msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}" + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(out_dir) + elif it % log_interval == 0: + # loss as float. note: this is a CPU-GPU sync point + loss = batch.loss.item() + msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py new file mode 100644 index 00000000000..e16fbf45474 --- /dev/null +++ b/examples/rlhf/train_reward.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
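The loop in `train.py` above combines gradient accumulation with `torch.cuda.amp.GradScaler` and gradient clipping. As a reference, here is a minimal, self-contained restatement of that update pattern with a toy linear model standing in for the GPT-2 forward pass (it assumes a CUDA device, as the default `sys.device` does); the ordering scale → backward → unscale → clip → step → update matches the loop above.

```python
import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# In train.py the scaler is only enabled when dtype == "float16"
scaler = torch.cuda.amp.GradScaler(enabled=True)
grad_accum_steps, grad_clip = 2, 1.0

def next_batch():  # stand-in for next(train_loader)
    return torch.randn(16, 10, device="cuda"), torch.randn(16, 1, device="cuda")

for _ in range(grad_accum_steps):
    x, y = next_batch()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()      # gradients accumulate across micro-batches

if grad_clip != 0.0:
    scaler.unscale_(optimizer)         # unscale once, just before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)                 # skipped internally if inf/NaN grads are found
scaler.update()
optimizer.zero_grad(set_to_none=True)
```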
+import time + +import hydra +import torch +from models.reward import init_reward_model +from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.reward import PairwiseDataset +from utils import get_file_logger, resolve_name_or_path, setup + + +def _accuracy(chosen_end_scores, rejected_end_scores): + return ( + sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) + ).item() + + +# TODO: eliminate redundant repeated definition +# helps estimate an arbitrarily accurate loss over either split using many batches +def create_loss_estimator(eval_iters, ctx): + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + accs = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + losses[k] = model.compute_reward_loss( + batch.chosen_data, batch.rejected_data + ).item() + accs[k] = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + model.train() + return losses.mean(), accs.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train_reward") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "reward_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + reward_model_cfg = cfg.reward_model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + reward_out_dir = reward_model_cfg.out_dir + + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="valid1", + ) + + if reward_model_cfg.init_from == "resume": + model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.out_dir), + device=device, + compile_model=compile_, + ) + else: + model = init_reward_model( + transformer_path=resolve_name_or_path(model_cfg.name_or_path), + device=device, + compile_model=compile_, + ) + # Freeze the first 70% of the hidden layers of the reward model backbone + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + # ######## INIT TRAINING FUNCTIONS ######## + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + estimate_loss = create_loss_estimator(eval_iters, ctx) + + best_val_loss = float("inf") + + t0 = time.time() + for it in range(1, max_iters + 1): + batch = next(train_loader) + + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + optimizer.zero_grad(set_to_none=True) + loss = model.compute_reward_loss(batch.chosen_data, batch.rejected_data) + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + 
t0 = t1 + if it % eval_interval == 0: + val_loss, val_acc = estimate_loss(model, val_loader) + train_loss, train_acc = estimate_loss(model, train_loader) + + msg = ( + f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}, " + f"{train_acc=:.4f}, {val_acc=:.4f}" + ) + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {reward_out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(reward_out_dir) + elif it % log_interval == 0: + loss = loss.item() + acc = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py new file mode 100644 index 00000000000..7dce72e7dd4 --- /dev/null +++ b/examples/rlhf/train_rlhf.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hydra +import torch +from models.actor_critic import init_actor_critic +from torchrl.data.rlhf.utils import AdaptiveKLController, RolloutFromModel + +from torchrl.record.loggers import get_logger + +from tqdm import tqdm + +from utils import ( + flatten_td, + freeze_layers, + get_prompt_loaders, + make_evaluator, + make_loss, + make_optimizer, + make_ref_model, + make_replay_buffer, + make_reward_model, + make_sub_replay_buffer, + resolve_name_or_path, + setup, + TrainLogger, +) + + +@hydra.main(version_base="1.1", config_path="config", config_name="train_rlhf") +def main(cfg): + + # ============ Retrieve config ============ # + ############################################# + + # make path absolute + cfg.model.name_or_path = resolve_name_or_path(cfg.model.name_or_path) + + # Get some constants: number of iters, grad clip... 
+ batch_size = cfg.data.batch_size + num_rollouts_per_epoch = cfg.train.ppo.num_rollouts_per_epoch + collection_iters = num_rollouts_per_epoch // batch_size + + grad_clip = cfg.train.grad_clip + max_epochs = cfg.train.max_epochs + + ppo_batch_size = cfg.train.ppo.ppo_batch_size + ppo_num_epochs = cfg.train.ppo.ppo_num_epochs + + device = cfg.sys.device + + # ============ Instantiate utils ============ # + ############################################### + ctx = setup(cfg.sys) + + logger = get_logger( + logger_type=cfg.io.logger, logger_name="./log", experiment_name="torchrlhf-gpt2" + ) + + # =============== Dataloaders =============== # + ############################################### + # We use prompts to get generated data from the generative model + + train_prompt_loader, val_prompt_loader = get_prompt_loaders(cfg.data, cfg.sys) + + # ================= Models ================= # + ############################################## + # Actor (gen model) - critic (value predictor) + actor, critic, critic_head, model = init_actor_critic(cfg.model, cfg.sys) + # Freeze initial model to use as ref + ref_model = make_ref_model(model, sys_cfg=cfg.sys) + # Freeze layers of the model -- can be customized + freeze_layers(model) + + reward_model = make_reward_model(reward_model_cfg=cfg.reward_model, sys_cfg=cfg.sys) + + # ================= Loss and optimizer ================= # + ########################################################## + loss_fn, advantage = make_loss(actor, critic, critic_head) + + optimizer, lr_scheduler = make_optimizer(cfg.train, loss_fn) + + # ================= Replay buffer ================= # + ##################################################### + rb = make_replay_buffer(cfg.train.ppo, cfg.data) + + # ================= Data collector ================= # + ###################################################### + # + # Because we interact with HuggingFace's transformers models, + # using a Gym-like API (querying steps etc) introduces some + # extra code that we can spare. + # + kl_scheduler = AdaptiveKLController( + model, init_kl_coef=0.1, target=6, horizon=10000 + ) + rollout_from_model = RolloutFromModel( + model, + ref_model, + reward_model, + kl_scheduler=kl_scheduler, + num_steps=collection_iters, + ) + + # ================= Evaluation utils ================= # + ######################################################## + evaluator = make_evaluator( + ppo_cfg=cfg.train.ppo, + io_cfg=cfg.io, + model_cfg=cfg.model, + train_cfg=cfg.train, + val_prompt_loader=val_prompt_loader, + model=model, + ref_model=ref_model, + reward_model=reward_model, + ctx=ctx, + logger=logger, + ) + + # ================= Training loop ================= # + ##################################################### + + stats_logger = TrainLogger( + collection_iters, log_interval=cfg.io.log_interval, logger=logger + ) + pbar = tqdm(total=max_epochs * collection_iters) + for _ in range(max_epochs): + # ----------------- 1. Collect data, fill replay buffer ----------------- # + # it's possible we didn't fill the replay buffer in the last iteration if + # generation stopped early, so we empty first before repopulating + rb.empty() + for _ in range(collection_iters): + batch = next(train_prompt_loader) + td = rollout_from_model.rollout_from_data(batch) + with torch.no_grad(), ctx: + # TODO: moving this to within epoch + advantage(td) + rb.extend(flatten_td(td)) + stats_logger(td) + stats_logger.aggregate() + stats_logger.log() + + rollout_from_model.step_scheduler() + + # ----------------- 2. 
Feed model ----------------- # + for batch in rb: + rb_ppo = make_sub_replay_buffer(batch, batch_size=ppo_batch_size) + for _ in range(ppo_num_epochs): # PPO epochs + optimizer.zero_grad() + for minibatch in rb_ppo: # GO over RB + minibatch = minibatch.to(device, non_blocking=True) + with ctx: + loss_vals = loss_fn(minibatch) + loss_val = sum( + value + for key, value in loss_vals.items() + if key.startswith("loss") + ) + loss_val.backward() + torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), grad_clip) + optimizer.step() + if lr_scheduler is not None: + lr_scheduler.step() + pbar.update(1) + + # ----------------- 3. Possibly evaluate ----------------- # + evaluator.maybe_evaluate() + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py new file mode 100644 index 00000000000..198b2e72bcb --- /dev/null +++ b/examples/rlhf/utils.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import contextlib +import logging +from contextlib import nullcontext +from copy import deepcopy + +import torch +import torch._dynamo + +from hydra.utils import to_absolute_path +from models.reward import init_reward_model + +from tensordict import TensorDict +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data import ( + LazyTensorStorage, + RolloutFromModel, + TensorDictReplayBuffer, + TensorStorage, +) +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from torchrl.objectives import ClipPPOLoss +from torchrl.objectives.value import GAE + +from torchrl.record.loggers import Logger +from transformers import GenerationConfig, GPT2Tokenizer + + +class TestPromptLogger: + def __init__(self, batch, reward_model, logger, episode_length): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + test_rindex = batch.prompt_rindex[0] + test_prompt_ids = batch.input_ids[:1, :test_rindex] + test_label_ids = batch.input_ids[:1, test_rindex:] + test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) + test_label = tokenizer.decode( + test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() + ) + _, test_label_reward = reward_model( + input_ids=batch.input_ids[:1], attention_mask=batch.attention_mask[:1] + ) + self.generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length + ) + self.test_prompt_ids = test_prompt_ids + self.reward_model = reward_model + self.tokenizer = tokenizer + self.test_label_reward = test_label_reward + self.test_rindex = test_rindex + self.test_prompt = test_prompt + self.test_label = test_label + self.logger = logger + + def log(self, model): + response_ids = model.generate( + input_ids=self.test_prompt_ids, generation_config=self.generation_config + ) + _, response_reward = self.reward_model( + input_ids=response_ids, + attention_mask=(response_ids != self.tokenizer.pad_token_id).to( + torch.int64 + ), + ) + reward = (response_reward - self.test_label_reward).item() + response_ids = response_ids[0, self.test_rindex :] + response = self.tokenizer.decode( + response_ids[response_ids != self.tokenizer.eos_token_id].tolist() + ) + string_to_write = ( + f"Query:\n{self.test_prompt}\n" + f"Response:\n{response}\n" + f"Actual 
response:\n{self.test_label}\n" + f"{reward=:4.4f}\n" + f"====================================================\n" + ) + self.logger.info(string_to_write) + + +class TrainLogger: + def __init__(self, size: int, log_interval: int, logger: Logger): + self.data = TensorDict({}, [size]) + self.counter = 0 + self.log_interval = log_interval + self.logger = logger + self.it = -1 + + def __call__(self, data): + done = data.get(("next", "done")) + td_done = data[done.view(data.shape)] + next_reward = td_done.get(("next", "reward_raw")) + next_kl = td_done.get(("next", "reward_kl")) + self.data[self.counter]["next_reward"] = next_reward.mean().cpu() + self.data[self.counter]["next_kl"] = next_kl.mean().cpu() + self.counter += 1 + + def aggregate(self): + result = {} + for key, item in self.data.items(): + result[key] = item.mean() + self.aggregated_data = TensorDict(result, []) + + def log(self): + self.it += 1 + if self.it % self.log_interval == 0: + for key, item in self.aggregated_data.items(): + self.logger.log_scalar(key, item) + + +class Evaluator: + def __init__( + self, + *, + reward_estimator, + model, + prompt_logger, + io_cfg, + val_reward_logger, + val_loader, + rlhf_out_dir, + always_save_checkpoint=False, + ctx=None, + logger=None, + ): + self.reward_estimator = reward_estimator + self.model = model + self.prompt_logger = prompt_logger + self.io_cfg = io_cfg + self.eval_interval = io_cfg.eval_interval + self.log_interval = io_cfg.log_interval + self.eval_iters = io_cfg.eval_iters + if ctx is None: + ctx = contextlib.nullcontext() + self.ctx = ctx + self.val_reward_logger = val_reward_logger + self.val_loader = val_loader + self.always_save_checkpoint = always_save_checkpoint + self.rlhf_out_dir = rlhf_out_dir + self.logger = logger + + self.best_val_reward = -float("inf") + self.it = 0 + + def maybe_evaluate(self): + self.it += 1 + if self.it % self.eval_interval == 0: + with self.ctx: + val_reward = self.reward_estimator(self.model, self.val_loader) + self.prompt_logger.log(self.model) + self.val_reward_logger.info(f"VALID: {self.it=}: {val_reward=:.4f}") + self.logger.log_scalar("val_reward", val_reward, step=self.it) + # pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}") + if val_reward > self.best_val_reward: + self.best_val_reward = val_reward + if self.always_save_checkpoint: + if self.it > 0: + self.val_reward_logger.info( + f"saving checkpoint to {self.rlhf_out_dir}" + ) + self.model.save_pretrained(self.rlhf_out_dir) + + +class RewardEstimator: + """Create a class to estimate the reward via sampling. + + This class exposes a call method which, given a model and a dataloader, will + perform multiple rollouts using the model and data sampled from the dataloader, then + average the accumulated rewards. + + For debugging purposes, we also generate responses to a fixed prompt so that the + quality of the model can be visually assessed during training. + + """ + + def __init__(self, eval_iters, episode_length, reward_model, ref_model): + """ + Args: + eval_iters (int): number of batches on which we would like to estimate reward + + episode_length (int): max number of generated new tokens + + reward_model (GPT2RewardModel): reward model + + ref_model (GPT2LMHeadModel): original transformer model that is used to + correctly compute the KL component of the reward.
+ """ + self.ref_model = ref_model + self.reward_model = reward_model + self.eval_iters = eval_iters + self.episode_length = episode_length + + @torch.no_grad() + def __call__(self, model, dataloader): + rollout_from_model = RolloutFromModel( + model, + self.ref_model, + self.reward_model, + kl_coef=0, # disable KL for evaluation + max_new_tokens=self.episode_length, + ) + rewards = torch.zeros(self.eval_iters) + for k in range(self.eval_iters): + batch = next(dataloader) + td = rollout_from_model.rollout_from_data(batch) + rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item() + test_reward = rewards.mean() + + return test_reward + + +def resolve_name_or_path(name_or_path): + """Hydra changes the working directory, so we need to absolutify paths.""" + if not name_or_path: + return None + if name_or_path.startswith("./") or name_or_path.startswith("/"): + return to_absolute_path(name_or_path) + return name_or_path + + +def get_file_logger(name, filename, level=logging.DEBUG): + """ + Set up logger that will log to the given filename. + """ + logger = logging.getLogger(name) + handler = logging.FileHandler(filename) + handler.setFormatter( + # logging.Formatter("%(asctime)s, %(name)s %(levelname)s %(message)s") + logging.Formatter("%(asctime)s - %(message)s") + ) + logger.addHandler(handler) + logger.setLevel(level) + return logger + + +def setup(sys_cfg): + """ + Set manual seed, configure backend and autocasting. + """ + device = sys_cfg.device + dtype = sys_cfg.dtype + + torch.manual_seed(1337) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + torch._dynamo.config.cache_size_limit = 256 + + if "cuda" not in device: + return nullcontext() + + return torch.amp.autocast(device_type="cuda", dtype=getattr(torch, dtype)) + + +def flatten_td(td): + # our tensordict has shape [B, T] where B = batch_size and T = trajectory length + # some trajectories may have stopped (reached EOS) before generating T tokens + # this function truncates and concatenates the trajectories, resulting in a + # tensordict that has shape [N] where N <= B * T. 
+ done = td["next", "done"] + mask = torch.zeros_like(done) + mask[..., 1:, :] = done[..., :-1, :] # shift by one + mask = ~mask.cumsum(-2).bool().squeeze() + return td[mask] + + +def make_evaluator( + ppo_cfg, + io_cfg, + model_cfg, + train_cfg, + val_prompt_loader, + model, + ref_model, + reward_model, + ctx, + logger, +): + query_logger = get_file_logger("query_logger", "rlhf_query_logger.log") + val_reward_logger = get_file_logger("val_reward_logger", "rlhf_valid_rewards.log") + episode_length = ppo_cfg.episode_length + rlhf_out_dir = model_cfg.out_dir + always_save_checkpoint = train_cfg.always_save_checkpoint + + test_prompt = next(val_prompt_loader) + prompt_logger = TestPromptLogger( + batch=test_prompt, + reward_model=reward_model, + logger=query_logger, + episode_length=episode_length, + ) + reward_estimator = RewardEstimator( + io_cfg.eval_iters, episode_length, reward_model, ref_model + ) + + evaluator = Evaluator( + reward_estimator=reward_estimator, + model=model, + prompt_logger=prompt_logger, + io_cfg=io_cfg, + val_reward_logger=val_reward_logger, + val_loader=val_prompt_loader, + rlhf_out_dir=rlhf_out_dir, + always_save_checkpoint=always_save_checkpoint, + ctx=ctx, + logger=logger, + ) + return evaluator + + +def make_replay_buffer(ppo_cfg, data_cfg): + return TensorDictReplayBuffer( + storage=LazyTensorStorage( + ppo_cfg.episode_length * ppo_cfg.num_rollouts_per_epoch + ), + batch_size=ppo_cfg.episode_length * data_cfg.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + ) + + +def get_prompt_loaders(data_cfg, sys_cfg): + train_prompt_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + sys_cfg.device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + num_workers=data_cfg.num_workers, + ) + val_prompt_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + sys_cfg.device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + num_workers=data_cfg.num_workers, + ) + return train_prompt_loader, val_prompt_loader + + +def make_ref_model(model, sys_cfg): + device = sys_cfg.ref_device + ref_model = deepcopy(model).to(device) + ref_model.requires_grad_(False) + return ref_model + + +def freeze_layers(model): + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + +def make_reward_model(reward_model_cfg, sys_cfg): + device = sys_cfg.device + compile_model = sys_cfg.compile + reward_model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.name_or_path), + device=device, + compile_model=compile_model, + ) + reward_model.eval() + reward_model.requires_grad_(False) + return reward_model + + +def make_loss(actor, critic, critic_head): + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True, shifted=True + ) + loss_fn = ClipPPOLoss(actor, critic_head) + return loss_fn, advantage + + +def make_optimizer(train_cfg, loss_fn): + optimizer = torch.optim.AdamW( + [p for p in loss_fn.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + return optimizer, scheduler + + +def make_sub_replay_buffer(data, batch_size): + """A zero-copy sub-replay buffer.""" + # We expect some overhead due to the instantiation of the rb, storage and sampler + # but hopefully these shouldn't be as big as copying data. 
+ # An optimized version of this would cache the rb, storage container and sampler and + # just rewire to the new data location. + storage = TensorStorage(data.exclude("index")) + rb = TensorDictReplayBuffer( + storage=storage, batch_size=batch_size, sampler=SamplerWithoutReplacement() + ) + return rb diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 22cba615d30..2d3425a2151 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,41 +1,41 @@ -# Environment +# environment and task env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-SAC" - library: gym - frame_skip: 1 - seed: 1 + exp_name: ${env.name}_SAC + library: gymnasium + max_episode_steps: 1000 + seed: 42 -# Collection +# collector collector: - total_frames: 1000000 - init_random_frames: 10000 + total_frames: 1_000_000 + init_random_frames: 25000 frames_per_batch: 1000 - max_frames_per_traj: 1000 init_env_steps: 1000 - async_collection: 1 collector_device: cpu env_per_collector: 1 - num_workers: 1 + reset_at_each_iter: False -# Replay Buffer +# replay buffer replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay + scratch_dir: ${env.exp_name}_${env.seed} -# Optimization -optimization: +# optim +optim: utd_ratio: 1.0 gamma: 0.99 - loss_function: smooth_l1 - lr: 3e-4 - weight_decay: 2e-4 - lr_scheduler: "" + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 + alpha_init: 1.0 + adam_eps: 1.0e-8 -# Algorithm +# network network: hidden_sizes: [256, 256] activation: relu @@ -43,7 +43,7 @@ network: scale_lb: 0.1 device: "cuda:0" -# Logging +# logging logger: backend: wandb mode: online diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 17b997cfda6..33b932ec42c 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -11,17 +11,20 @@ The helper functions are coded in the utils.py associated with this script. 
""" +import time + import hydra import numpy as np import torch import torch.cuda import tqdm - +from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_environment, make_loss_module, @@ -35,6 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create logger exp_name = generate_exp_name("SAC", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -48,132 +52,158 @@ def main(cfg: "DictConfig"): # noqa: F821 torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + + # Create agent model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) - # Create TD3 loss + # Create SAC loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) - # Make Optimizers - optimizer = make_sac_optimizer(cfg, loss_module) - - rewards = [] - rewards_eval = [] + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_sac_optimizer(cfg, loss_module) # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip - eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + sampling_start = time.time() for i, tensordict in enumerate(collector): - # update weights of the inference policy + sampling_time = time.time() - sampling_start + + # Update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) - tensordict = tensordict.view(-1) + tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps + training_start = time.time() if collected_frames >= init_random_frames: - (actor_losses, q_losses, alpha_losses) = ([], [], []) - for _ in range(num_updates): - # sample from replay buffer + losses = TensorDict( + {}, + batch_size=[ + num_updates, + ], + ) + for i in range(num_updates): + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() + # Compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] - loss = actor_loss + q_loss + alpha_loss - 
optimizer.zero_grad() - loss.backward() - optimizer.step() + # Update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - alpha_losses.append(alpha_loss.item()) + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() - # update qnet_target params + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + losses[i] = loss_td.select( + "loss_actor", "loss_qvalue", "loss_alpha" + ).detach() + + # Update qnet_target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "alpha_loss": np.mean(alpha_losses), - "alpha": loss_td["alpha"], - "entropy": loss_td["entropy"], - } + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item() + metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() + metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() + metrics_to_log["train/alpha"] = loss_td["alpha"].item() + metrics_to_log["train/entropy"] = loss_td["entropy"].item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git 
a/examples/sac/utils.py b/examples/sac/utils.py index 4e84bde12c9..ebbee32057b 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import torch from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor @@ -6,8 +11,8 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal @@ -20,16 +25,22 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) +def env_maker(task, device="cpu"): + with set_gym_backend("gym"): + return GymEnv( + task, + device=device, + ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps=1000): transformed_env = TransformedEnv( env, Compose( - RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(in_keys=["observation"], in_keys_inv=[]), + InitTracker(), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), ), ) return transformed_env @@ -43,7 +54,7 @@ def make_environment(cfg): ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms(parallel_env) + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) eval_env = TransformedEnv( ParallelEnv( @@ -65,8 +76,8 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=cfg.collector.collector_device, ) @@ -114,17 +125,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def make_sac_agent(cfg, train_env, eval_env, device): """Make SAC agent.""" # Define Actor Network @@ -142,8 +142,8 @@ def make_sac_agent(cfg, train_env, eval_env, device): dist_class = TanhNormal dist_kwargs = { - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, "tanh_loc": False, } @@ -214,24 +214,68 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=False, delay_qvalue=True, + alpha_init=cfg.optim.alpha_init, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define 
Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + def make_sac_optimizer(cfg, loss_module): - """Make SAC optimizer.""" - optimizer = optim.Adam( - loss_module.parameters(), - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) - return optimizer + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=3.0e-4, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 35a2d9f8b2f..4ef557ed50c 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,47 +1,50 @@ -# Environment +# task and env env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-TD3" - library: gym - frame_skip: 1 + exp_name: ${env.name}_TD3 + library: gymnasium seed: 42 + max_episode_steps: 1000 -# Collection +# collector collector: total_frames: 1000000 - init_random_frames: 10000 + init_random_frames: 25_000 init_env_steps: 1000 frames_per_batch: 1000 - max_frames_per_traj: 1000 - async_collection: 1 + reset_at_each_iter: False collector_device: cpu env_per_collector: 1 num_workers: 1 -# Replay Buffer +# replay buffer replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 + scratch_dir: ${env.exp_name}_${env.seed} -# Optimization -optimization: +# optim +optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 3e-4 - weight_decay: 2e-4 + lr: 3.0e-4 + weight_decay: 0.0 + adam_eps: 1e-4 batch_size: 256 target_update_polyak: 0.995 policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 -# Network +# network network: hidden_sizes: [256, 256] activation: relu device: "cuda:0" -# Logging +# logging logger: backend: wandb mode: online diff --git a/examples/td3/td3.py b/examples/td3/td3.py index f4d8707f404..7c9904f5300 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -11,8 +11,9 @@ The helper functions are coded in the utils.py associated with this script. 
""" -import hydra +import time +import hydra import numpy as np import torch import torch.cuda @@ -22,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_environment, make_loss_module, @@ -35,6 +37,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create logger exp_name = generate_exp_name("TD3", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -45,140 +48,155 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) + # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + # Create agent model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device) # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) - # Make Optimizers + # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) - rewards = [] - rewards_eval = [] - # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) - delayed_updates = cfg.optimization.policy_update_delay + delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector - eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip + eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip + frames_per_batch = cfg.collector.frames_per_batch + update_counter = 0 - for i, tensordict in enumerate(collector): + sampling_start = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start exploration_policy.step(tensordict.numel()) - # update weights of the inference policy + + # Update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps + training_start = time.time() if collected_frames >= init_random_frames: ( actor_losses, q_losses, ) = ([], []) - for j in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample().clone() + for _ in range(num_updates): + + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 - loss_td = loss_module(sampled_tensordict) + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample().clone() - actor_loss = 
loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] + # Compute loss + q_loss, *_ = loss_module.value_loss(sampled_tensordict) + # Update critic optimizer_critic.zero_grad() - update_actor = j % delayed_updates == 0 - q_loss.backward(retain_graph=update_actor) + q_loss.backward() optimizer_critic.step() q_losses.append(q_loss.item()) + # Update actor if update_actor: + actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) optimizer_actor.zero_grad() actor_loss.backward() optimizer_actor.step() + actor_losses.append(actor_loss.item()) - # update qnet_target params + # Update target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - } + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses) + if update_actor: + metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 529ad43138b..090529782fd 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -1,3 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import tempfile +from contextlib import nullcontext + import torch from torch import nn, optim @@ -10,10 +17,11 @@ EnvCreator, InitTracker, ParallelEnv, + RewardSum, + StepCounter, TransformedEnv, ) -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( AdditiveGaussianWrapper, @@ -33,17 +41,27 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) +def env_maker( + task, + device="cpu", + from_pixels=False, +): + with set_gym_backend("gym"): + return GymEnv( + task, + device=device, + from_pixels=from_pixels, + ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps): transformed_env = TransformedEnv( env, Compose( + StepCounter(max_steps=max_episode_steps), InitTracker(), - RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(in_keys=["observation"], in_keys_inv=[]), + DoubleToFloat(), + RewardSum(), ), ) return transformed_env @@ -53,16 +71,18 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms(parallel_env) + train_env = apply_env_transforms( + parallel_env, max_episode_steps=cfg.env.max_episode_steps + ) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ), train_env.transform.clone(), ) @@ -79,9 +99,10 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, + reset_at_each_iter=cfg.collector.reset_at_each_iter, device=cfg.collector.collector_device, ) collector.set_seed(cfg.env.seed) @@ -92,35 +113,40 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir="/tmp/", + buffer_scratch_dir=None, device="cpu", prefetch=3, ): - if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - else: - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - return replay_buffer + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, 
+ ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return replay_buffer # ==================================================================== @@ -128,17 +154,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def make_td3_agent(cfg, train_env, eval_env, device): """Make TD3 agent.""" # Define Actor Network @@ -222,17 +237,18 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, + gamma=cfg.optim.gamma, action_spec=model[0][1].spec, + policy_noise=cfg.optim.policy_noise, + noise_clip=cfg.optim.noise_clip, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater @@ -241,11 +257,36 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/setup.cfg b/setup.cfg index 3f9ce9e3e4b..55e98280d3e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,9 +15,18 @@ per-file-ignores = test/smoke_test_deps.py: F401 test_*.py: F841, E731, E266 test/opengl_rendering.py: F401 + test/test_modules.py: F841, E731, E266, TOR101 + test/test_tensordictmodules.py: F841, E731, E266, TOR101 + torchrl/objectives/cql.py: TOR101 + torchrl/objectives/deprecated.py: TOR101 + torchrl/objectives/iql.py: TOR101 + torchrl/objectives/redq.py: TOR101 + torchrl/objectives/sac.py: TOR101 + torchrl/objectives/td3.py: TOR101 + torchrl/objectives/value/advantages.py: TOR101 exclude = venv -extend-select = B901, C401, C408, C409 +extend-select = B901, C401, C408, C409, TOR0, TOR1, TOR2 [pydocstyle] ;select = D417 # Missing argument descriptions in the docstring diff --git a/setup.py b/setup.py index 3723c1b1981..2d768354bb1 100644 --- a/setup.py +++ b/setup.py @@ -169,15 +169,16 @@ def _main(argv): if is_nightly: tensordict_dep = 
"tensordict-nightly" else: - tensordict_dep = "tensordict>=0.1.1" + tensordict_dep = "tensordict>=0.2.0" if is_nightly: version = get_nightly_version() write_version_file(version) - print("Building wheel {}-{}".format(package_name, version)) - print(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") else: version = get_version() + write_version_file(version) + print("Building wheel {}-{}".format(package_name, version)) + print(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") pytorch_package_dep = _get_pytorch_version(is_nightly) print("-- PyTorch dependency:", pytorch_package_dep) @@ -188,6 +189,35 @@ def _main(argv): long_description = (this_directory / "README.md").read_text() sys.argv = [sys.argv[0]] + unknown + extra_requires = { + "atari": [ + "gym", + "atari-py", + "ale-py", + "gym[accept-rom-license]", + "pygame", + ], + "dm_control": ["dm_control"], + "gym_continuous": ["gymnasium", "mujoco"], + "rendering": ["moviepy"], + "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"], + "utils": [ + "tensorboard", + "wandb", + "tqdm", + "hydra-core>=1.1", + "hydra-submitit-launcher", + "git", + ], + "checkpointing": [ + "torchsnapshot", + ], + "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], + } + extra_requires["all"] = set() + for key in list(extra_requires.keys()): + extra_requires["all"] = extra_requires["all"].union(extra_requires[key]) + extra_requires["all"] = sorted(extra_requires["all"]) setup( # Metadata name=name, @@ -212,37 +242,13 @@ def _main(argv): "cloudpickle", tensordict_dep, ], - extras_require={ - "atari": [ - "gym<=0.24", - "atari-py", - "ale-py", - "gym[accept-rom-license]", - "pygame", - ], - "dm_control": ["dm_control"], - "gym_continuous": ["mujoco-py", "mujoco"], - "rendering": ["moviepy"], - "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"], - "utils": [ - "tensorboard", - "wandb", - "tqdm", - "hydra-core>=1.1", - "hydra-submitit-launcher", - "git", - ], - "checkpointing": [ - "torchsnapshot", - ], - "marl": ["vmas"], - }, + extras_require=extra_requires, zip_safe=False, classifiers=[ - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Development Status :: 4 - Beta", diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 85e76790c26..00758268593 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -17,10 +17,13 @@ import torch import torch.cuda -from tensordict import tensorclass +from tensordict import tensorclass, TensorDict from torchrl._utils import implement_for, seed_generator +from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs import ObservationNorm +from torchrl.envs import MultiThreadedEnv, ObservationNorm +from torchrl.envs.batched_envs import ParallelEnv, SerialEnv +from torchrl.envs.libs.envpool import _has_envpool from torchrl.envs.libs.gym import _has_gym, GymEnv from torchrl.envs.transforms import ( Compose, @@ -28,7 +31,6 @@ ToTensorImage, TransformedEnv, ) -from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnv, ParallelEnv, SerialEnv # Specified for test_utils.py __version__ = "0.3" @@ -320,3 +322,143 @@ class MyClass: for key in td.keys(): MyClass.__annotations__[key] = torch.Tensor return tensorclass(MyClass) + + +def rollout_consistency_assertion( + rollout, *, done_key="done", observation_key="observation" +): + """Tests that 
observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" + + done = rollout[:, :-1]["next", done_key].squeeze(-1) + # data resulting from step, when it's not done + r_not_done = rollout[:, :-1]["next"][~done] + # data resulting from step, when it's not done, after step_mdp + r_not_done_tp1 = rollout[:, 1:][~done] + torch.testing.assert_close( + r_not_done[observation_key], r_not_done_tp1[observation_key] + ) + + if not done.any(): + return + + # data resulting from step, when it's done + r_done = rollout[:, :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[:, 1:][done] + assert ( + (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 + ).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) + + +def rand_reset(env): + """Generates a tensordict with reset keys that mimic the done spec. + + Values are drawn at random until at least one reset is present. + + """ + full_done_spec = env.full_done_spec + result = {} + for reset_key, list_of_done in zip(env.reset_keys, env.done_keys_groups): + val = full_done_spec[list_of_done[0]].rand() + while not val.any(): + val = full_done_spec[list_of_done[0]].rand() + result[reset_key] = val + # create a data structure that keeps the batch size of the nested specs + result = ( + full_done_spec.zero().update(result).exclude(*full_done_spec.keys(True, True)) + ) + return result + + +def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int): + index_batch_size = (0,) * (len(td.batch_size) - 1) + + # Check done and reset for root + observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1 + next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "done"][observation_is_max]).all() + assert (~td["next", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all() + # Obs after not done is previous obs + assert ( + td["observation"][index_batch_size][1:][~next_is_done] + == td["next", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for root + action_is_count = td["action"].long().argmax(-1).to(torch.bool) + assert ( + td["next", "observation"][action_is_count] + == td["observation"][action_is_count] + 1 + ).all() + assert (td["next", "reward"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for root + assert ( + td["next", "observation"][~action_is_count] + == td["observation"][~action_is_count] + ).all() + assert (td["next", "reward"][~action_is_count] == 0).all() + + # Check done and reset for nested_1 + observation_is_max = td["next", "nested_1", "observation"][..., 0] == max_steps + 1 + next_is_done = td["next", "nested_1", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "nested_1", "done"][observation_is_max]).all() + assert (~td["next", "nested_1", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert ( + td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0 + ).all() + # Obs after not done is previous obs + assert ( + td["nested_1", "observation"][index_batch_size][1:][~next_is_done] + == td["next", "nested_1", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for nested_1 + action_is_count = 
td["nested_1"]["action"].to(torch.bool) + assert ( + td["next", "nested_1", "observation"][action_is_count] + == td["nested_1", "observation"][action_is_count] + 1 + ).all() + assert (td["next", "nested_1", "gift"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for nested_1 + assert ( + td["next", "nested_1", "observation"][~action_is_count] + == td["nested_1", "observation"][~action_is_count] + ).all() + assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all() + + # Check done and reset for nested_2 + observation_is_max = td["next", "nested_2", "observation"][..., 0] == max_steps + 1 + next_is_done = td["next", "nested_2", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "nested_2", "done"][observation_is_max]).all() + assert (~td["next", "nested_2", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert ( + td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0 + ).all() + # Obs after not done is previous obs + assert ( + td["nested_2", "observation"][index_batch_size][1:][~next_is_done] + == td["next", "nested_2", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for nested_2 + action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool) + assert ( + td["next", "nested_2", "observation"][action_is_count] + == td["nested_2", "observation"][action_is_count] + 1 + ).all() + assert (td["next", "nested_2", "reward"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for nested_2 + assert ( + td["next", "nested_2", "observation"][~action_is_count] + == td["nested_2", "observation"][~action_is_count] + ).all() + assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all() + + +def decorate_thread_sub_func(func, num_threads): + def new_func(*args, **kwargs): + assert torch.get_num_threads() == num_threads + return func(*args, **kwargs) + + return CloudpickleWrapper(new_func) diff --git a/test/assets/generate.py b/test/assets/generate.py index 3e565bb6b69..75a87bb71b5 100644 --- a/test/assets/generate.py +++ b/test/assets/generate.py @@ -21,7 +21,7 @@ def generate_small_dataset(comparison=True): smalld = {} for key in list(d.keys()): if any(key.startswith(sub) for sub in ("train", "valid", "test")): - smalld[key] = Dataset.from_dict(d[key][:1000]) + smalld[key] = Dataset.from_dict(d[key][:50]) smalld = DatasetDict(smalld) if comparison: diff --git a/test/assets/openai_summarize_comparisons.zip b/test/assets/openai_summarize_comparisons.zip index 535a8849ec1..e48ba2a35ab 100644 Binary files a/test/assets/openai_summarize_comparisons.zip and b/test/assets/openai_summarize_comparisons.zip differ diff --git a/test/assets/openai_summarize_tldr.zip b/test/assets/openai_summarize_tldr.zip index f4fefd24b14..ad8daf11226 100644 Binary files a/test/assets/openai_summarize_tldr.zip and b/test/assets/openai_summarize_tldr.zip differ diff --git a/test/conftest.py b/test/conftest.py index c5cfdd680e7..048b9e6c49e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -54,7 +54,7 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -66,3 +66,23 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + 
category=DeprecationWarning, + message=r"Deprecated call to `pkg_resources.declare_namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Using or importing the ABCs", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Please use `coo_matrix` from the `scipy.sparse` namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated", + ) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 6d5107fcc64..d71a0b5cbb3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import NestedKey +from tensordict.utils import expand_right, NestedKey from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, @@ -19,6 +19,7 @@ TensorSpec, UnboundedContinuousTensorSpec, ) +from torchrl.data.utils import consolidate_spec from torchrl.envs.common import EnvBase from torchrl.envs.model_based.common import ModelBasedEnvBase @@ -59,27 +60,38 @@ def __new__( *args, **kwargs, ): - for key, item in list(cls._output_spec["_observation_spec"].items()): - cls._output_spec["_observation_spec"][key] = item.to( + for key, item in list(cls._output_spec["full_observation_spec"].items()): + cls._output_spec["full_observation_spec"][key] = item.to( torch.get_default_dtype() ) - cls._output_spec["_reward_spec"] = cls._output_spec["_reward_spec"].to( - torch.get_default_dtype() - ) - if not isinstance(cls._output_spec["_reward_spec"], CompositeSpec): - cls._output_spec["_reward_spec"] = CompositeSpec( - reward=cls._output_spec["_reward_spec"], - shape=cls._output_spec["_reward_spec"].shape[:-1], + reward_spec = cls._output_spec["full_reward_spec"] + if isinstance(reward_spec, CompositeSpec): + reward_spec = CompositeSpec( + { + key: item.to(torch.get_default_dtype()) + for key, item in reward_spec.items(True, True) + }, + shape=reward_spec.shape, + device=reward_spec.device, ) - if not isinstance(cls._output_spec["_done_spec"], CompositeSpec): - cls._output_spec["_done_spec"] = CompositeSpec( - done=cls._output_spec["_done_spec"], - shape=cls._output_spec["_done_spec"].shape[:-1], + else: + reward_spec = reward_spec.to(torch.get_default_dtype()) + cls._output_spec["full_reward_spec"] = reward_spec + if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): + cls._output_spec["full_reward_spec"] = CompositeSpec( + reward=cls._output_spec["full_reward_spec"], + shape=cls._output_spec["full_reward_spec"].shape[:-1], ) - if not isinstance(cls._input_spec["_action_spec"], CompositeSpec): - cls._input_spec["_action_spec"] = CompositeSpec( - action=cls._input_spec["_action_spec"], - shape=cls._input_spec["_action_spec"].shape[:-1], + if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): + cls._output_spec["full_done_spec"] = CompositeSpec( + done=cls._output_spec["full_done_spec"].clone(), + terminated=cls._output_spec["full_done_spec"].clone(), + shape=cls._output_spec["full_done_spec"].shape[:-1], + ) + if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): + cls._input_spec["full_action_spec"] = CompositeSpec( + action=cls._input_spec["full_action_spec"], + shape=cls._input_spec["full_action_spec"].shape[:-1], ) return super().__new__(cls, *args, **kwargs) @@ -90,7 +102,7 @@ def __init__( **kwargs, ): super().__init__( - device="cpu", 
+ device=kwargs.pop("device", "cpu"), dtype=torch.get_default_dtype(), ) self.set_seed(seed) @@ -162,25 +174,25 @@ def __new__( if state_spec is None: state_spec = CompositeSpec(shape=batch_size) input_spec = CompositeSpec( - _action_spec=action_spec, _state_spec=state_spec, shape=batch_size + full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size ) cls._output_spec = CompositeSpec(shape=batch_size) - cls._output_spec["_reward_spec"] = reward_spec - cls._output_spec["_done_spec"] = done_spec - cls._output_spec["_observation_spec"] = observation_spec + cls._output_spec["full_reward_spec"] = reward_spec + cls._output_spec["full_done_spec"] = done_spec + cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = input_spec - if not isinstance(cls._output_spec["_reward_spec"], CompositeSpec): - cls._output_spec["_reward_spec"] = CompositeSpec( - reward=cls._output_spec["_reward_spec"], shape=batch_size + if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): + cls._output_spec["full_reward_spec"] = CompositeSpec( + reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["_done_spec"], CompositeSpec): - cls._output_spec["_done_spec"] = CompositeSpec( - done=cls._output_spec["_done_spec"], shape=batch_size + if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): + cls._output_spec["full_done_spec"] = CompositeSpec( + done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["_action_spec"], CompositeSpec): - cls._input_spec["_action_spec"] = CompositeSpec( - action=cls._input_spec["_action_spec"], shape=batch_size + if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): + cls._input_spec["full_action_spec"] = CompositeSpec( + action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(*args, **kwargs) @@ -203,9 +215,10 @@ def _step(self, tensordict): done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict( { - "next": TensorDict( - {"reward": n, "done": done, "observation": n.clone()}, batch_size=[] - ) + "reward": n, + "done": done, + "terminated": done.clone(), + "observation": n.clone(), }, batch_size=[], ) @@ -218,7 +231,9 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: ) done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) - return TensorDict({"done": done, "observation": n}, []) + return TensorDict( + {"done": done, "terminated": done.clone(), "observation": n}, [] + ) def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: return self.step(tensordict) @@ -276,25 +291,25 @@ def __new__( if done_spec is None: done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) cls._output_spec = CompositeSpec(shape=batch_size) - cls._output_spec["_reward_spec"] = reward_spec - cls._output_spec["_done_spec"] = done_spec - cls._output_spec["_observation_spec"] = observation_spec + cls._output_spec["full_reward_spec"] = reward_spec + cls._output_spec["full_done_spec"] = done_spec + cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = CompositeSpec( - _action_spec=action_spec, - _state_spec=state_spec, + full_action_spec=action_spec, + full_state_spec=state_spec, shape=batch_size, ) - if not isinstance(cls._output_spec["_reward_spec"], CompositeSpec): - cls._output_spec["_reward_spec"] = CompositeSpec( - 
reward=cls._output_spec["_reward_spec"], shape=batch_size + if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): + cls._output_spec["full_reward_spec"] = CompositeSpec( + reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["_done_spec"], CompositeSpec): - cls._output_spec["_done_spec"] = CompositeSpec( - done=cls._output_spec["_done_spec"], shape=batch_size + if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): + cls._output_spec["full_done_spec"] = CompositeSpec( + done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["_action_spec"], CompositeSpec): - cls._input_spec["_action_spec"] = CompositeSpec( - action=cls._input_spec["_action_spec"], shape=batch_size + if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): + cls._input_spec["full_action_spec"] = CompositeSpec( + action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(cls, *args, **kwargs) @@ -337,13 +352,7 @@ def _step(self, tensordict): device=self.device, ) return TensorDict( - { - "next": TensorDict( - {"reward": n, "done": done, "observation": n}, - tensordict.batch_size, - device=self.device, - ) - }, + {"reward": n, "done": done, "terminated": done.clone(), "observation": n}, batch_size=tensordict.batch_size, device=self.device, ) @@ -376,7 +385,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: device=self.device, ) return TensorDict( - {"done": done, "observation": n}, + {"done": done, "terminated": done.clone(), "observation": n}, [ *leading_batch_size, *batch_size, @@ -430,16 +439,22 @@ def __new__( shape=batch_size, ) if action_spec is None: - action_spec_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) - action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) + if categorical_action_encoding: + action_spec_cls = DiscreteTensorSpec + action_spec = action_spec_cls(n=7, shape=batch_size) + else: + action_spec_cls = OneHotDiscreteTensorSpec + action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + reward_spec = CompositeSpec( + reward=UnboundedContinuousTensorSpec(shape=(1,)) + ) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = CompositeSpec( + terminated=DiscreteTensorSpec( + 2, dtype=torch.bool, shape=(*batch_size, 1) + ) + ) if state_spec is None: cls._out_key = "observation_orig" @@ -450,12 +465,12 @@ def __new__( shape=batch_size, ) cls._output_spec = CompositeSpec(shape=batch_size) - cls._output_spec["_reward_spec"] = reward_spec - cls._output_spec["_done_spec"] = done_spec - cls._output_spec["_observation_spec"] = observation_spec + cls._output_spec["full_reward_spec"] = reward_spec + cls._output_spec["full_done_spec"] = done_spec + cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = CompositeSpec( - _action_spec=action_spec, - _state_spec=state_spec, + full_action_spec=action_spec, + full_state_spec=state_spec, shape=batch_size, ) cls.from_pixels = from_pixels @@ -476,6 +491,9 @@ def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase: tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state)) tensordict = tensordict.set(self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) + tensordict.set( + 
"terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool) + ) return tensordict def _step( @@ -496,11 +514,13 @@ def _step( done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) reward = done.any(-1).unsqueeze(-1) + # set done to False done = torch.zeros_like(done).all(-1).unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) - return tensordict.select().set("next", tensordict) + tensordict.set("terminated", done.clone()) + return tensordict class ContinuousActionVecMockEnv(_MockEnv): @@ -552,12 +572,12 @@ def __new__( shape=batch_size, ) cls._output_spec = CompositeSpec(shape=batch_size) - cls._output_spec["_reward_spec"] = reward_spec - cls._output_spec["_done_spec"] = done_spec - cls._output_spec["_observation_spec"] = observation_spec + cls._output_spec["full_reward_spec"] = reward_spec + cls._output_spec["full_done_spec"] = done_spec + cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = CompositeSpec( - _action_spec=action_spec, - _state_spec=state_spec, + full_action_spec=action_spec, + full_state_spec=state_spec, shape=batch_size, ) cls.from_pixels = from_pixels @@ -580,6 +600,9 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: # tensordict.set("next_" + self.out_key, self._get_out_obs(state)) # tensordict.set("next_" + self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) + tensordict.set( + "terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool) + ) return tensordict def _step( @@ -602,7 +625,8 @@ def _step( done = reward = done.unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) - return tensordict.select().set("next", tensordict) + tensordict.set("terminated", done) + return tensordict def _obs_step(self, obs, a): return obs + a / self.maxstep @@ -662,7 +686,7 @@ def __new__( cls._out_key = "pixels_orig" state_spec = CompositeSpec( { - cls._out_key: observation_spec["pixels_orig"], + cls._out_key: observation_spec["pixels_orig"].clone(), }, shape=batch_size, ) @@ -807,9 +831,6 @@ def _get_out_obs(self, obs): def _get_in_obs(self, obs): obs = obs.diagonal(0, -1, -2) - # if any(dim == 1 for dim in obs.shape): - # print("squeezing obs", obs.shape) - # obs = obs.squeeze() return obs @@ -1023,6 +1044,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: source={ "observation": self.count.clone(), "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, }, batch_size=self.batch_size, device=self.device, @@ -1038,12 +1060,33 @@ def _step( source={ "observation": self.count.clone(), "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, "reward": torch.zeros_like(self.count, dtype=torch.float), }, batch_size=self.batch_size, device=self.device, ) - return tensordict.select().set("next", tensordict) + return tensordict + + +class IncrementingEnv(CountingEnv): + # Same as CountingEnv but always increments the count by 1 regardless of the action. + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 # The only difference with CountingEnv. 
+ tensordict = TensorDict( + source={ + "observation": self.count.clone(), + "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict class NestedCountingEnv(CountingEnv): @@ -1119,20 +1162,11 @@ def __init__( ) if self.nested_done: + done_spec = self.full_done_spec.unsqueeze(-1).expand( + *self.batch_size, self.nested_dim + ) self.done_spec = CompositeSpec( - { - "data": CompositeSpec( - { - "done": self.done_spec.unsqueeze(-1).expand( - *self.batch_size, self.nested_dim, 1 - ) - }, - shape=( - *self.batch_size, - self.nested_dim, - ), - ) - }, + {"data": done_spec}, shape=self.batch_size, ) @@ -1146,10 +1180,15 @@ def _reset(self, tensordict): tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool) td = super()._reset(tensordict) if self.nested_done: - td[self.done_key] = ( - td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1) - ) - del td["done"] + for done_key in self.done_keys: + if isinstance(done_key, str): + done_key = (done_key,) + td[done_key] = ( + td[done_key[-1]] + .unsqueeze(-1) + .expand(*self.batch_size, self.nested_dim, 1) + ) + del td[done_key[-1]] if self.nested_obs_action: td["data", "states"] = ( td["observation"] @@ -1166,7 +1205,7 @@ def _step(self, td): td = td.clone() td["data"].batch_size = self.batch_size td[self.action_key] = td[self.action_key].max(-2)[0] - td_root = super()._step(td) + next_td = super()._step(td) if self.nested_obs_action: td[self.action_key] = ( td[self.action_key] @@ -1175,12 +1214,17 @@ def _step(self, td): ) if "data" in td.keys(): td["data"].batch_size = (*self.batch_size, self.nested_dim) - td = td_root["next"] + td = next_td if self.nested_done: - td[self.done_key] = ( - td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1) - ) - del td["done"] + for done_key in self.done_keys: + if isinstance(done_key, str): + done_key = (done_key,) + td[done_key] = ( + td[done_key[-1]] + .unsqueeze(-1) + .expand(*self.batch_size, self.nested_dim, 1) + ) + del td[done_key[-1]] if self.nested_obs_action: td["data", "states"] = ( td["observation"] @@ -1195,7 +1239,7 @@ def _step(self, td): del td["reward"] if "data" in td.keys(): td["data"].batch_size = (*self.batch_size, self.nested_dim) - return td_root + return td class CountingBatchedEnv(EnvBase): @@ -1269,6 +1313,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: source={ "observation": self.count.clone(), "done": self.count > self.max_steps.view_as(self.count), + "terminated": self.count > self.max_steps.view_as(self.count), }, batch_size=self.batch_size, device=self.device, @@ -1284,9 +1329,455 @@ def _step( source={ "observation": self.count.clone(), "done": self.count > self.max_steps.unsqueeze(-1), + "terminated": self.count > self.max_steps.unsqueeze(-1), "reward": torch.zeros_like(self.count, dtype=torch.float), }, batch_size=self.batch_size, device=self.device, ) - return tensordict.select().set("next", tensordict) + return tensordict + + +class HeteroCountingEnvPolicy: + def __init__(self, full_action_spec: TensorSpec, count: bool = True): + self.full_action_spec = full_action_spec + self.count = count + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + action_td = self.full_action_spec.zero() + if self.count: + action_td.apply_(lambda x: x + 1) + return td.update(action_td) + + +class HeteroCountingEnv(EnvBase): + """A heterogeneous, 
counting Env.""" + + def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): + super().__init__(**kwargs) + self.n_nested_dim = 3 + self.max_steps = max_steps + self.start_val = start_val + + count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int) + count[:] = self.start_val + + self.register_buffer("count", count) + + obs_specs = [] + action_specs = [] + for index in range(self.n_nested_dim): + obs_specs.append(self.get_agent_obs_spec(index)) + action_specs.append(self.get_agent_action_spec(index)) + obs_specs = torch.stack(obs_specs, dim=0) + obs_spec_unlazy = consolidate_spec(obs_specs) + action_specs = torch.stack(action_specs, dim=0) + + self.unbatched_observation_spec = CompositeSpec( + lazy=obs_spec_unlazy, + state=UnboundedContinuousTensorSpec( + shape=( + 64, + 64, + 3, + ) + ), + ) + + self.unbatched_action_spec = CompositeSpec( + lazy=action_specs, + ) + self.unbatched_reward_spec = CompositeSpec( + { + "lazy": CompositeSpec( + { + "reward": UnboundedContinuousTensorSpec( + shape=(self.n_nested_dim, 1) + ) + }, + shape=(self.n_nested_dim,), + ) + } + ) + self.unbatched_done_spec = CompositeSpec( + { + "lazy": CompositeSpec( + { + "done": DiscreteTensorSpec( + n=2, + shape=(self.n_nested_dim, 1), + dtype=torch.bool, + ), + }, + shape=(self.n_nested_dim,), + ) + } + ) + + self.action_spec = self.unbatched_action_spec.expand( + *self.batch_size, *self.unbatched_action_spec.shape + ) + self.observation_spec = self.unbatched_observation_spec.expand( + *self.batch_size, *self.unbatched_observation_spec.shape + ) + self.reward_spec = self.unbatched_reward_spec.expand( + *self.batch_size, *self.unbatched_reward_spec.shape + ) + self.done_spec = self.unbatched_done_spec.expand( + *self.batch_size, *self.unbatched_done_spec.shape + ) + + def get_agent_obs_spec(self, i): + camera = BoundedTensorSpec(low=0, high=200, shape=(7, 7, 3)) + vector_3d = UnboundedContinuousTensorSpec(shape=(3,)) + vector_2d = UnboundedContinuousTensorSpec(shape=(2,)) + lidar = BoundedTensorSpec(low=0, high=5, shape=(8,)) + + tensor_0 = UnboundedContinuousTensorSpec(shape=(1,)) + tensor_1 = BoundedTensorSpec(low=0, high=3, shape=(1, 2)) + tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3)) + + if i == 0: + return CompositeSpec( + { + "camera": camera, + "lidar": lidar, + "vector": vector_3d, + "tensor_0": tensor_0, + } + ) + elif i == 1: + return CompositeSpec( + { + "camera": camera, + "lidar": lidar, + "vector": vector_2d, + "tensor_1": tensor_1, + } + ) + elif i == 2: + return CompositeSpec( + { + "camera": camera, + "vector": vector_2d, + "tensor_2": tensor_2, + } + ) + else: + raise ValueError(f"Index {i} undefined for index 3") + + def get_agent_action_spec(self, i): + action_3d = BoundedTensorSpec(low=-1, high=1, shape=(3,)) + action_2d = BoundedTensorSpec(low=-1, high=1, shape=(2,)) + + # Some have 2d action and some 3d + # TODO Introduce composite heterogeneous actions + if i == 0: + ret = action_3d + elif i == 1: + ret = action_2d + elif i == 2: + ret = action_2d + else: + raise ValueError(f"Index {i} undefined for index 3") + + return CompositeSpec({"action": ret}) + + def _reset( + self, + tensordict: TensorDictBase = None, + **kwargs, + ) -> TensorDictBase: + if tensordict is not None and self.reset_keys[0] in tensordict.keys(True): + _reset = tensordict.get(self.reset_keys[0]).squeeze(-1).any(-1) + self.count[_reset] = self.start_val + else: + self.count[:] = self.start_val + + reset_td = self.observation_spec.zero() + reset_td.apply_(lambda x: x + 
expand_right(self.count, x.shape)) + reset_td.update(self.output_spec["full_done_spec"].zero()) + + assert reset_td.batch_size == self.batch_size + + return reset_td + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + actions = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool) + for i in range(self.n_nested_dim): + action = tensordict["lazy"][..., i]["action"] + action = action[..., 0].to(torch.bool) + actions += action + + self.count += actions.unsqueeze(-1).to(torch.int) + + td = self.observation_spec.zero() + td.apply_(lambda x: x + expand_right(self.count, x.shape)) + td.update(self.output_spec["full_done_spec"].zero()) + td.update(self.output_spec["full_reward_spec"].zero()) + + assert td.batch_size == self.batch_size + for done_key in self.done_keys: + td[done_key] = expand_right( + self.count > self.max_steps, + self.full_done_spec[done_key].shape, + ) + + return td + + def _set_seed(self, seed: Optional[int]): + torch.manual_seed(seed) + + +class MultiKeyCountingEnvPolicy: + def __init__( + self, + full_action_spec: TensorSpec, + count: bool = True, + deterministic: bool = False, + ): + if not deterministic and not count: + raise ValueError("Not counting policy is always deterministic") + + self.full_action_spec = full_action_spec + self.count = count + self.deterministic = deterministic + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + action_td = self.full_action_spec.zero() + if self.count: + if self.deterministic: + action_td["nested_1", "action"] += 1 + action_td["nested_2", "azione"] += 1 + action_td["action"][..., 1] = 1 + else: + # We choose an action at random + choice = torch.randint(0, 3, ()).item() + if choice == 0: + action_td["nested_1", "action"] += 1 + elif choice == 1: + action_td["nested_2", "azione"] += 1 + else: + action_td["action"][..., 1] = 1 + return td.update(action_td) + + +class MultiKeyCountingEnv(EnvBase): + def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): + super().__init__(**kwargs) + + self.max_steps = max_steps + self.start_val = start_val + self.nested_dim_1 = 3 + self.nested_dim_2 = 2 + + count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int) + count_nested_1 = torch.zeros( + (*self.batch_size, self.nested_dim_1, 1), + device=self.device, + dtype=torch.int, + ) + count_nested_2 = torch.zeros( + (*self.batch_size, self.nested_dim_2, 1), + device=self.device, + dtype=torch.int, + ) + + count[:] = self.start_val + count_nested_1[:] = self.start_val + count_nested_2[:] = self.start_val + + self.register_buffer("count", count) + self.register_buffer("count_nested_1", count_nested_1) + self.register_buffer("count_nested_2", count_nested_2) + + self.make_specs() + + self.action_spec = self.unbatched_action_spec.expand( + *self.batch_size, *self.unbatched_action_spec.shape + ) + self.observation_spec = self.unbatched_observation_spec.expand( + *self.batch_size, *self.unbatched_observation_spec.shape + ) + self.reward_spec = self.unbatched_reward_spec.expand( + *self.batch_size, *self.unbatched_reward_spec.shape + ) + self.done_spec = self.unbatched_done_spec.expand( + *self.batch_size, *self.unbatched_done_spec.shape + ) + + def make_specs(self): + self.unbatched_observation_spec = CompositeSpec( + nested_1=CompositeSpec( + observation=BoundedTensorSpec( + low=0, high=200, shape=(self.nested_dim_1, 3) + ), + shape=(self.nested_dim_1,), + ), + nested_2=CompositeSpec( + observation=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 2)), + 
shape=(self.nested_dim_2,), + ), + observation=UnboundedContinuousTensorSpec( + shape=( + 10, + 10, + 3, + ) + ), + ) + + self.unbatched_action_spec = CompositeSpec( + nested_1=CompositeSpec( + action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)), + shape=(self.nested_dim_1,), + ), + nested_2=CompositeSpec( + azione=BoundedTensorSpec(low=0, high=100, shape=(self.nested_dim_2, 1)), + shape=(self.nested_dim_2,), + ), + action=OneHotDiscreteTensorSpec(n=2), + ) + + self.unbatched_reward_spec = CompositeSpec( + nested_1=CompositeSpec( + gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)), + shape=(self.nested_dim_1,), + ), + nested_2=CompositeSpec( + reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)), + shape=(self.nested_dim_2,), + ), + reward=UnboundedContinuousTensorSpec(shape=(1,)), + ) + + self.unbatched_done_spec = CompositeSpec( + nested_1=CompositeSpec( + done=DiscreteTensorSpec( + n=2, + shape=(self.nested_dim_1, 1), + dtype=torch.bool, + ), + terminated=DiscreteTensorSpec( + n=2, + shape=(self.nested_dim_1, 1), + dtype=torch.bool, + ), + shape=(self.nested_dim_1,), + ), + nested_2=CompositeSpec( + done=DiscreteTensorSpec( + n=2, + shape=(self.nested_dim_2, 1), + dtype=torch.bool, + ), + terminated=DiscreteTensorSpec( + n=2, + shape=(self.nested_dim_2, 1), + dtype=torch.bool, + ), + shape=(self.nested_dim_2,), + ), + done=DiscreteTensorSpec( + n=2, + shape=(1,), + dtype=torch.bool, + ), + terminated=DiscreteTensorSpec( + n=2, + shape=(1,), + dtype=torch.bool, + ), + ) + + def _reset( + self, + tensordict: TensorDictBase = None, + **kwargs, + ) -> TensorDictBase: + reset_all = False + if tensordict is not None: + _reset = tensordict.get("_reset", None) + if _reset is not None: + self.count[_reset.squeeze(-1)] = self.start_val + + _reset_nested_1 = tensordict.get(("nested_1", "_reset"), None) + if _reset_nested_1 is not None: + self.count_nested_1[_reset_nested_1.squeeze(-1)] = self.start_val + + _reset_nested_2 = tensordict.get(("nested_2", "_reset"), None) + if _reset_nested_2 is not None: + self.count_nested_2[_reset_nested_2.squeeze(-1)] = self.start_val + + if _reset is None and _reset_nested_1 is None and _reset_nested_2 is None: + reset_all = True + + if tensordict is None or reset_all: + self.count[:] = self.start_val + self.count_nested_1[:] = self.start_val + self.count_nested_2[:] = self.start_val + + reset_td = self.observation_spec.zero() + reset_td["observation"] += expand_right( + self.count, reset_td["observation"].shape + ) + reset_td["nested_1", "observation"] += expand_right( + self.count_nested_1, reset_td["nested_1", "observation"].shape + ) + reset_td["nested_2", "observation"] += expand_right( + self.count_nested_2, reset_td["nested_2", "observation"].shape + ) + + reset_td.update(self.output_spec["full_done_spec"].zero()) + + assert reset_td.batch_size == self.batch_size + + return reset_td + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + + # Each action has a corresponding reward, done, and observation + reward = self.output_spec["full_reward_spec"].zero() + done = self.output_spec["full_done_spec"].zero() + td = self.observation_spec.zero() + + one_hot_action = tensordict["action"] + one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1) + reward["reward"] += one_hot_action.to(torch.float) + self.count += one_hot_action.to(torch.int) + td["observation"] += expand_right(self.count, td["observation"].shape) + done["done"] = self.count > self.max_steps + done["terminated"] = self.count > 
self.max_steps + + discrete_action = tensordict["nested_1"]["action"].unsqueeze(-1) + reward["nested_1"]["gift"] += discrete_action.to(torch.float) + self.count_nested_1 += discrete_action.to(torch.int) + td["nested_1", "observation"] += expand_right( + self.count_nested_1, td["nested_1", "observation"].shape + ) + done["nested_1", "done"] = self.count_nested_1 > self.max_steps + done["nested_1", "terminated"] = self.count_nested_1 > self.max_steps + + continuous_action = tensordict["nested_2"]["azione"] + reward["nested_2"]["reward"] += continuous_action.to(torch.float) + self.count_nested_2 += continuous_action.to(torch.bool) + td["nested_2", "observation"] += expand_right( + self.count_nested_2, td["nested_2", "observation"].shape + ) + done["nested_2", "done"] = self.count_nested_2 > self.max_steps + done["nested_2", "terminated"] = self.count_nested_2 > self.max_steps + + td.update(done) + td.update(reward) + + assert td.batch_size == self.batch_size + return td + + def _set_seed(self, seed: Optional[int]): + torch.manual_seed(seed) diff --git a/test/test_actors.py b/test/test_actors.py index ee358cbe25a..8b432e9ac21 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -51,7 +51,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), maximum=1, minimum=-1 + shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 ) policy_module = TensorDictModule( nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] @@ -112,7 +112,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), maximum=1, minimum=-1 + shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 ) actor_net = nn.Sequential( nn.Linear(1, 2), @@ -217,7 +217,7 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5): env = NestedCountingEnv( nest_obs_action=nested_action, batch_size=batch_size, nested_dim=nested_dim ) - action_spec = env._input_spec["_action_spec"] + action_spec = env._input_spec["full_action_spec"] leaf_action_spec = env.action_spec space_str, spec = _process_action_space_spec(None, action_spec) @@ -266,6 +266,20 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5): spec=action_spec, ) + @pytest.mark.parametrize( + "action_space, var_nums, expected_action", + ( + ("multi_one_hot", [2, 2, 2], [1, 0, 1, 0, 1, 0]), + ("multi_one_hot", [2, 4], [1, 0, 1, 0, 0, 0]), + ), + ) + def test_qvalue_module_multi_one_hot(self, action_space, var_nums, expected_action): + module = QValueModule(action_space=action_space, var_nums=var_nums) + in_values = torch.tensor([1.0, 0, 2, 0, 1, 0]) + action, values, chosen_action_value = module(in_values) + assert (torch.tensor(expected_action, dtype=torch.long) == action).all() + assert (values == in_values).all() + @pytest.mark.parametrize( "action_space, expected_action", ( @@ -599,6 +613,39 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action assert values.shape == in_values.shape assert (values == in_values).all() + @pytest.mark.parametrize("action_space", ["categorical", "one-hot"]) + @pytest.mark.parametrize("action_n", [2, 3, 4, 5]) + def test_qvalue_mask(self, action_space, action_n): + 
torch.manual_seed(0) + shape = (3, 4, 3, action_n) + action_values = torch.randn(size=shape) + td = TensorDict({"action_value": action_values}, [3]) + module = QValueModule( + action_space=action_space, + action_value_key="action_value", + action_mask_key="action_mask", + ) + with pytest.raises(KeyError, match="Action mask key "): + module(td) + + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + while not action_mask.any(dim=-1).all() or action_mask.all(): + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + + td.set("action_mask", action_mask) + module(td) + new_action_values = td.get("action_value") + + assert (new_action_values[~action_mask] != action_values[~action_mask]).all() + assert (new_action_values[action_mask] == action_values[action_mask]).all() + assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all() + + if action_space == "one-hot": + assert (td.get("action")[action_mask]).any() + assert not (td.get("action")[~action_mask]).any() + else: + assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all() + @pytest.mark.parametrize("device", get_default_devices()) def test_value_based_policy(device): @@ -756,10 +803,11 @@ def test_actorcritic(device): @pytest.mark.skipif(not _has_transformers, reason="missing dependencies") @pytest.mark.parametrize("device", get_default_devices()) def test_lmhead_actorvalueoperator(device): - from transformers import AutoModelForCausalLM + from transformers import AutoModelForCausalLM, GPT2Config - base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False) - aco = LMHeadActorValueOperator(base_model) + config = GPT2Config(return_dict=False) + base_model = AutoModelForCausalLM.from_config(config).eval() + aco = LMHeadActorValueOperator(base_model).to(device) # check common assert aco.module[0][0].module is base_model.transformer @@ -786,7 +834,8 @@ def test_lmhead_actorvalueoperator(device): batch_size=[ 4, ], - ).to(device) + device=device, + ) td_total = aco(td.clone()) policy_op = aco.get_policy_operator() td_policy = policy_op(td.clone()) diff --git a/test/test_collector.py b/test/test_collector.py index 607f09d635a..3d71bb09a8c 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -4,12 +4,20 @@ # LICENSE file in the root directory of this source tree. 
import argparse + import sys import numpy as np import pytest import torch -from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED +from _utils_internal import ( + check_rollout_consistency_multikey_env, + decorate_thread_sub_func, + generate_seeds, + get_default_devices, + PENDULUM_VERSIONED, + PONG_VERSIONED, +) from mocking_classes import ( ContinuousActionVecMockEnv, CountingBatchedEnv, @@ -19,11 +27,16 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, + HeteroCountingEnv, + HeteroCountingEnvPolicy, MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, NestedCountingEnv, ) from tensordict.nn import TensorDictModule from tensordict.tensordict import assert_allclose_td, TensorDict + from torch import nn from torchrl._utils import prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector @@ -43,15 +56,21 @@ SerialEnv, StepCounter, ) -from torchrl.envs.libs.gym import _has_gym, GymEnv +from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend from torchrl.envs.transforms import TransformedEnv, VecNorm +from torchrl.envs.utils import ( + _aggregate_resets, + _replace_last, + check_env_specs, + PARTIAL_MISSING_ERR, +) from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule # torch.set_default_dtype(torch.double) -_os_is_windows = sys.platform == "win32" -_python_is_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10 -_python_is_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7 -_os_is_osx = sys.platform == "darwin" +IS_WINDOWS = sys.platform == "win32" +IS_OSX = sys.platform == "darwin" +PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10 +PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7 class WrappablePolicy(nn.Module): @@ -159,7 +178,7 @@ def _is_consistent_device_type( @pytest.mark.skipif( - _os_is_windows and _python_is_3_10, + IS_WINDOWS and PYTHON_3_10, reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", ) @pytest.mark.parametrize("num_env", [2]) @@ -174,7 +193,7 @@ def test_output_device_consistency( ) and not torch.cuda.is_available(): pytest.skip("cuda is not available") - if _os_is_windows and _python_is_3_7: + if IS_WINDOWS and PYTHON_3_7: if device == "cuda" and policy_device == "cuda" and device is None: pytest.skip( "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" @@ -246,6 +265,7 @@ def env_fn(seed): assert d.names[-1] == "time" ccollector.shutdown() + del ccollector @pytest.mark.parametrize("num_env", [1, 2]) @@ -321,7 +341,10 @@ def test_collector_env_reset(): torch.manual_seed(0) def make_env(): - return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter()) + # This is currently necessary as the methods in GymWrapper may have mismatching backend + # versions. 
+ with set_gym_backend(gym_backend()): + return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter()) env = SerialEnv(2, make_env) # env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4)) @@ -346,52 +369,52 @@ def make_env(): assert _data["next", "reward"].sum(-2).min() == -21 -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["vec"]) -def test_collector_done_persist(num_env, env_name, seed=5): - if num_env == 1: - - def env_fn(seed): - env = MockSerialEnv(device="cpu") - env.set_seed(seed) - return env - - else: - - def env_fn(seed): - def make_env(seed): - env = MockSerialEnv(device="cpu") - env.set_seed(seed) - return env - - env = ParallelEnv( - num_workers=num_env, - create_env_fn=make_env, - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - allow_step_when_done=True, - ) - env.set_seed(seed) - return env - - policy = make_policy(env_name) - - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=200 * num_env, - max_frames_per_traj=2000, - total_frames=20000, - device="cpu", - reset_when_done=False, - ) - for _, d in enumerate(collector): # noqa - break - - assert (d["done"].sum(-2) >= 1).all() - assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 - - del collector +# Deprecated reset_when_done +# @pytest.mark.parametrize("num_env", [1, 2]) +# @pytest.mark.parametrize("env_name", ["vec"]) +# def test_collector_done_persist(num_env, env_name, seed=5): +# if num_env == 1: +# +# def env_fn(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# def make_env(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# env = ParallelEnv( +# num_workers=num_env, +# create_env_fn=make_env, +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# env.set_seed(seed) +# return env +# +# policy = make_policy(env_name) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=200 * num_env, +# max_frames_per_traj=2000, +# total_frames=20000, +# device="cpu", +# reset_when_done=False, +# ) +# for _, d in enumerate(collector): # noqa +# break +# +# assert (d["done"].sum(-2) >= 1).all() +# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 +# +# del collector @pytest.mark.parametrize("frames_per_batch", [200, 10]) @@ -417,7 +440,6 @@ def make_env(seed): num_workers=num_env, create_env_fn=make_env, create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - allow_step_when_done=True, ) env.set_seed(seed) return env @@ -497,7 +519,7 @@ def test_collector_batch_size( num_env, env_name, seed=100, num_workers=2, frames_per_batch=20 ): """Tests that there are 'frames_per_batch' frames in each batch of a collection.""" - if num_env == 3 and _os_is_windows: + if num_env == 3 and IS_WINDOWS: pytest.skip("Test timeout (> 10 min) on CI pipeline Windows machine with GPU") if num_env == 1: @@ -550,6 +572,7 @@ def env_fn(): break assert b.names[-1] == "time" ccollector.shutdown() + del ccollector @pytest.mark.parametrize("num_env", [1, 2]) @@ -626,10 +649,11 @@ def env_fn(seed): # Get a single rollout with dummypolicy env = env_fn(seed) - rollout1a = env.rollout(policy=policy, max_steps=20, auto_reset=True) + env = TransformedEnv(env, StepCounter(20)) + rollout1a = env.rollout(policy=policy, max_steps=50, auto_reset=True) 
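With the StepCounter(20) transform added above, trajectory length is bounded by the transform rather than by the rollout budget passed to env.rollout. A minimal sketch of that behaviour, reusing the mock env from this test file purely for illustration:

from mocking_classes import ContinuousActionVecMockEnv
from torchrl.envs import StepCounter, TransformedEnv

env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(20))
rollout = env.rollout(max_steps=50)
# StepCounter flags truncation after 20 steps, so even with a 50-step budget
# the collected trajectory holds at most 20 transitions.
assert rollout.shape[-1] <= 20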
env.set_seed(seed) - rollout1b = env.rollout(policy=policy, max_steps=20, auto_reset=True) - rollout2 = env.rollout(policy=policy, max_steps=20, auto_reset=True) + rollout1b = env.rollout(policy=policy, max_steps=50, auto_reset=True) + rollout2 = env.rollout(policy=policy, max_steps=50, auto_reset=True) assert_allclose_td(rollout1a, rollout1b) with pytest.raises(AssertionError): assert_allclose_td(rollout1a, rollout2) @@ -656,7 +680,6 @@ def env_fn(seed): assert ( rollout1a.batch_size == b1.batch_size ), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}" - assert_allclose_td(rollout1a, b1.select(*rollout1a.keys(True, True))) collector.shutdown() @@ -767,6 +790,11 @@ def env_fn(seed): assert_allclose_td(data10, data20) +@pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Nested spawned multiprocessed is currently failing in python 3.11. " + "See https://github.com/python/cpython/pull/108568 for info and fix.", +) @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @pytest.mark.parametrize("static_seed", [True, False]) def test_collector_vecnorm_envcreator(static_seed): @@ -826,7 +854,7 @@ def test_collector_vecnorm_envcreator(static_seed): td4 = s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]["td"].clone() assert (td3 == td4).all() assert (td1 != td4).any() - + c.shutdown() del c @@ -934,6 +962,7 @@ def make_env(): break collector.shutdown() dummy_env.close() + del collector @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @@ -1025,6 +1054,11 @@ def test_collector_output_keys( } if split_trajs: keys.add(("collector", "mask")) + + keys.add(("next", "terminated")) + keys.add("terminated") + keys.add(("next", "truncated")) + keys.add("truncated") b = next(iter(collector)) assert set(b.keys(True)) == keys @@ -1036,12 +1070,7 @@ def test_collector_output_keys( @pytest.mark.parametrize("storing_device", ["cuda", "cpu"]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_collector_device_combinations(device, storing_device): - if ( - _os_is_windows - and _python_is_3_10 - and storing_device == "cuda" - and device == "cuda" - ): + if IS_WINDOWS and PYTHON_3_10 and storing_device == "cuda" and device == "cuda": pytest.skip("Windows fatal exception: access violation in torch.storage") def env_fn(seed): @@ -1108,6 +1137,7 @@ def env_fn(seed): batch = next(collector.iterator()) assert batch.device == torch.device(storing_device) collector.shutdown() + del collector @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @@ -1167,6 +1197,8 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): assert isinstance(collector.policy, TensorDictModule) assert collector.policy.out_keys == out_keys assert collector.policy.module is policy + collector.shutdown() + del collector def test_no_wrap_compatible_module(self, collector_class, env_maker): policy = TensorDictCompatiblePolicy( @@ -1191,6 +1223,8 @@ def test_no_wrap_compatible_module(self, collector_class, env_maker): assert isinstance(collector.policy, TensorDictCompatiblePolicy) assert collector.policy.out_keys == ["action"] assert collector.policy is policy + collector.shutdown() + del collector def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) @@ -1250,6 +1284,8 @@ def test_initial_obs_consistency(env_class, seed=1): expected_1 = torch.cat([arange_0, arange_0, arange]) expected = torch.stack([expected_0, expected_1]) 
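Several hunks above extend the expected collector output with terminated and truncated entries next to done, both at the root and under "next". A minimal sketch of the transition layout implied by that convention, with placeholder tensors and the usual reading that done aggregates the two flags:

import torch
from tensordict.tensordict import TensorDict

batch = 4
transition = TensorDict(
    {
        "observation": torch.randn(batch, 3),
        "action": torch.randn(batch, 1),
        "done": torch.zeros(batch, 1, dtype=torch.bool),
        "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        "truncated": torch.zeros(batch, 1, dtype=torch.bool),
        "next": {
            "observation": torch.randn(batch, 3),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),        # terminated | truncated
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),  # env reached a terminal state
            "truncated": torch.zeros(batch, 1, dtype=torch.bool),   # episode cut short (e.g. StepCounter)
        },
    },
    batch_size=[batch],
)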
assert torch.allclose(obs, expected.to(obs.dtype)) + collector.shutdown() + del collector def weight_reset(m): @@ -1257,7 +1293,7 @@ def weight_reset(m): m.reset_parameters() -@pytest.mark.skipif(_os_is_osx, reason="Queue.qsize does not work on osx.") +@pytest.mark.skipif(IS_OSX, reason="Queue.qsize does not work on osx.") class TestPreemptiveThreshold: @pytest.mark.parametrize("env_name", ["conv", "vec"]) def test_sync_collector_interruptor_mechanism(self, env_name, seed=100): @@ -1285,6 +1321,8 @@ def env_fn(seed): for batch in collector: assert batch["collector"]["traj_ids"][0] != -1 assert batch["collector"]["traj_ids"][1] == -1 + collector.shutdown() + del collector @pytest.mark.parametrize( "env_name", ["vec"] @@ -1309,8 +1347,8 @@ def env_fn(seed): frames_per_batch=frames_per_batch, init_random_frames=-1, reset_at_each_iter=False, - devices="cpu", - storing_devices="cpu", + devices=get_default_devices()[0], + storing_devices=get_default_devices()[0], split_trajs=False, preemptive_threshold=0.0, # stop after one iteration ) @@ -1319,6 +1357,8 @@ def env_fn(seed): trajectory_ids = batch["collector"]["traj_ids"] trajectory_ids_mask = trajectory_ids != -1 # valid frames mask assert trajectory_ids[trajectory_ids_mask].numel() < frames_per_batch + collector.shutdown() + del collector def test_maxframes_error(): @@ -1340,11 +1380,13 @@ def test_reset_heterogeneous_envs(): env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) env = SerialEnv(2, [env1, env2]) - c = SyncDataCollector( + collector = SyncDataCollector( env, RandomPolicy(env.action_spec), total_frames=10_000, frames_per_batch=1000 ) - for data in c: # noqa: B007 + for data in collector: # noqa: B007 break + collector.shutdown() + del collector assert ( data[0]["next", "truncated"].squeeze() == torch.tensor([False, True]).repeat(250)[:500] @@ -1355,6 +1397,26 @@ def test_reset_heterogeneous_envs(): ).all() +def test_policy_with_mask(): + env = CountingBatchedEnv(start_val=torch.tensor(10), max_steps=torch.tensor(1e5)) + + def policy(td): + obs = td.get("observation") + # This policy cannot work with obs all 0s + if not obs.any(): + raise AssertionError + action = obs.clone() + td.set("action", action) + return td + + collector = SyncDataCollector( + env, policy=policy, frames_per_batch=10, total_frames=20 + ) + for _ in collector: + break + collector.shutdown() + + class TestNestedEnvsCollector: def test_multi_collector_nested_env_consistency(self, seed=1): env = NestedCountingEnv() @@ -1367,7 +1429,7 @@ def test_multi_collector_nested_env_consistency(self, seed=1): policy=policy, frames_per_batch=20, total_frames=100, - device="cpu", + device=get_default_devices()[0], ) for i, d in enumerate(ccollector): if i == 0: @@ -1380,13 +1442,14 @@ def test_multi_collector_nested_env_consistency(self, seed=1): with pytest.raises(AssertionError): assert_allclose_td(c1, c2) ccollector.shutdown() + del ccollector ccollector = MultiSyncDataCollector( create_env_fn=[env_fn], policy=policy, frames_per_batch=20, total_frames=100, - device="cpu", + device=get_default_devices()[0], ) for i, d in enumerate(ccollector): if i == 0: @@ -1399,7 +1462,7 @@ def test_multi_collector_nested_env_consistency(self, seed=1): with pytest.raises(AssertionError): assert_allclose_td(d1, d2) ccollector.shutdown() - + del ccollector assert_allclose_td(c1, d1) assert_allclose_td(c2, d2) @@ -1426,12 +1489,13 @@ def test_collector_nested_env_combinations( policy=policy, 
frames_per_batch=frames_per_batch, total_frames=100, - device="cpu", + device=get_default_devices()[0], ) for _td in ccollector: break ccollector.shutdown() + del ccollector @pytest.mark.parametrize("batch_size", [(), (5,), (5, 2)]) def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): @@ -1447,13 +1511,13 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): policy=policy, frames_per_batch=frames_per_batch, total_frames=100, - device="cpu", + device=get_default_devices()[0], ) for _td in ccollector: break ccollector.shutdown() - + del ccollector assert ("data", "reward") not in _td.keys(True) assert _td.batch_size == (*batch_size, frames_per_batch // prod(batch_size)) assert _td["data"].batch_size == ( @@ -1468,6 +1532,171 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): ) +class TestHetEnvsCollector: + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("frames_per_batch", [4, 8, 16]) + def test_collector_het_env(self, batch_size, frames_per_batch, seed=1, max_steps=4): + batch_size = torch.Size(batch_size) + env = HeteroCountingEnv(max_steps=max_steps - 1, batch_size=batch_size) + torch.manual_seed(seed) + device = get_default_devices()[0] + policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + ccollector = SyncDataCollector( + create_env_fn=env, + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=device, + ) + + for _td in ccollector: + break + ccollector.shutdown() + collected_frames = frames_per_batch // batch_size.numel() + + for i in range(env.n_nested_dim): + if collected_frames >= max_steps: + agent_obs = _td["lazy"][(0,) * len(batch_size)][..., i][f"tensor_{i}"] + for _ in range(i + 1): + agent_obs = agent_obs.mean(-1) + assert ( + agent_obs + == torch.arange(max_steps, device=device).repeat( + collected_frames // max_steps + ) + ).all() # Check reset worked + assert (_td["lazy"][..., i]["action"] == 1).all() + del ccollector + + def test_multi_collector_het_env_consistency( + self, seed=1, frames_per_batch=20, batch_dim=10 + ): + env = HeteroCountingEnv(max_steps=3, batch_size=(batch_dim,)) + torch.manual_seed(seed) + env_fn = lambda: TransformedEnv(env, InitTracker()) + check_env_specs(env_fn(), return_contiguous=False) + policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + + ccollector = MultiaSyncDataCollector( + create_env_fn=[env_fn], + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=get_default_devices()[0], + ) + for i, d in enumerate(ccollector): + if i == 0: + c1 = d + elif i == 1: + c2 = d + else: + break + assert d.names[-1] == "time" + with pytest.raises(AssertionError): + assert_allclose_td(c1, c2) + ccollector.shutdown() + + ccollector = MultiSyncDataCollector( + create_env_fn=[env_fn], + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=get_default_devices()[0], + ) + for i, d in enumerate(ccollector): + if i == 0: + d1 = d + elif i == 1: + d2 = d + else: + break + assert d.names[-1] == "time" + with pytest.raises(AssertionError): + assert_allclose_td(d1, d2) + ccollector.shutdown() + del ccollector + + assert_allclose_td(c1, d1) + assert_allclose_td(c2, d2) + + +class TestMultiKeyEnvsCollector: + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("frames_per_batch", [4, 8, 16]) + @pytest.mark.parametrize("max_steps", [2, 3]) + def test_collector(self, batch_size, frames_per_batch, 
max_steps, seed=1): + env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps) + torch.manual_seed(seed) + policy = MultiKeyCountingEnvPolicy(env.input_spec["full_action_spec"]) + ccollector = SyncDataCollector( + create_env_fn=env, + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=get_default_devices()[0], + ) + + for _td in ccollector: + break + ccollector.shutdown() + del ccollector + for done_key in env.done_keys: + assert _replace_last(done_key, "_reset") not in _td.keys(True, True) + check_rollout_consistency_multikey_env(_td, max_steps=max_steps) + + def test_multi_collector_consistency( + self, seed=1, frames_per_batch=20, batch_dim=10 + ): + env = MultiKeyCountingEnv(batch_size=(batch_dim,)) + env_fn = lambda: env + torch.manual_seed(seed) + policy = MultiKeyCountingEnvPolicy( + env.input_spec["full_action_spec"], deterministic=True + ) + + ccollector = MultiaSyncDataCollector( + create_env_fn=[env_fn], + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=get_default_devices()[0], + ) + for i, d in enumerate(ccollector): + if i == 0: + c1 = d + elif i == 1: + c2 = d + else: + break + assert d.names[-1] == "time" + with pytest.raises(AssertionError): + assert_allclose_td(c1, c2) + ccollector.shutdown() + + ccollector = MultiSyncDataCollector( + create_env_fn=[env_fn], + policy=policy, + frames_per_batch=frames_per_batch, + total_frames=100, + device=get_default_devices()[0], + ) + for i, d in enumerate(ccollector): + if i == 0: + d1 = d + elif i == 1: + d2 = d + else: + break + assert d.names[-1] == "time" + with pytest.raises(AssertionError): + assert_allclose_td(d1, d2) + ccollector.shutdown() + del ccollector + + assert_allclose_td(c1, d1) + assert_allclose_td(c2, d2) + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda") class TestUpdateParams: class DummyEnv(EnvBase): @@ -1492,11 +1721,9 @@ def _step( self.state += action return TensorDict( { - "next": { - "state": self.state.clone(), - "reward": self.reward_spec.zero(), - "done": self.done_spec.zero(), - } + "state": self.state.clone(), + "reward": self.reward_spec.zero(), + **self.full_done_spec.zero(), }, self.batch_size, ) @@ -1568,6 +1795,269 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device): assert (data["action"] == 3).all() finally: col.shutdown() + del col + + +class TestAggregateReset: + def test_aggregate_reset_to_root(self): + # simple + td = TensorDict({"_reset": torch.zeros((1,), dtype=torch.bool)}, []) + assert _aggregate_resets(td).shape == () + # td with batch size + td = TensorDict({"_reset": torch.zeros((1,), dtype=torch.bool)}, [1]) + assert _aggregate_resets(td).shape == (1,) + td = TensorDict({"_reset": torch.zeros((1, 2), dtype=torch.bool)}, [1]) + assert _aggregate_resets(td).shape == (1,) + # nested td + td = TensorDict( + { + "_reset": torch.zeros((1,), dtype=torch.bool), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1], + ) + assert _aggregate_resets(td).shape == (1,) + # nested td with greater number of dims + td = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + # test reduction + assert _aggregate_resets(td).shape == (1, 2) + td = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + # test reduction, partial + assert 
_aggregate_resets(td).shape == (1, 2) + td = TensorDict( + { + "_reset": torch.tensor([True, False]).view(1, 2), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + assert (_aggregate_resets(td) == torch.tensor([True, False]).view(1, 2)).all() + # with a stack + td0 = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + "b": {"c": torch.randn(1, 2)}, + }, + [1, 2], + ) + td1 = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + "b": {"c": torch.randn(1, 2, 5)}, + }, + [1, 2], + ) + td = torch.stack([td0, td1], 0) + assert _aggregate_resets(td).all() + + def test_aggregate_reset_to_root_keys(self): + # simple + td = TensorDict({"_reset": torch.zeros((1,), dtype=torch.bool)}, []) + assert _aggregate_resets(td, reset_keys=["_reset"]).shape == () + # td with batch size + td = TensorDict({"_reset": torch.zeros((1,), dtype=torch.bool)}, [1]) + assert _aggregate_resets(td, reset_keys=["_reset"]).shape == (1,) + td = TensorDict({"_reset": torch.zeros((1, 2), dtype=torch.bool)}, [1]) + assert _aggregate_resets(td, reset_keys=["_reset"]).shape == (1,) + # nested td + td = TensorDict( + { + "_reset": torch.zeros((1,), dtype=torch.bool), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1], + ) + assert _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]).shape == ( + 1, + ) + # nested td with greater number of dims + td = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + # test reduction + assert _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]).shape == ( + 1, + 2, + ) + td = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + assert _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]).all() + # test reduction, partial + assert _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]).shape == ( + 1, + 2, + ) + td = TensorDict( + { + "_reset": torch.tensor( + [True, False], + ).view(1, 2), + "a": {"_reset": torch.zeros((1, 2), dtype=torch.bool)}, + }, + [1, 2], + ) + assert ( + _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]) + == torch.tensor([True, False]).view(1, 2) + ).all() + # with a stack + td0 = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + "b": {"c": torch.randn(1, 2)}, + }, + [1, 2], + ) + td1 = TensorDict( + { + "_reset": torch.zeros( + (1, 2), + dtype=torch.bool, + ), + "a": {"_reset": torch.ones((1, 2), dtype=torch.bool)}, + "b": {"c": torch.randn(1, 2, 5)}, + }, + [1, 2], + ) + td = torch.stack([td0, td1], 0) + assert _aggregate_resets(td, reset_keys=["_reset", ("a", "_reset")]).all() + + def test_aggregate_reset_to_root_errors(self): + # the order matters: if the first or another key is missing, the ValueError is raised at a different line + with pytest.raises(ValueError, match=PARTIAL_MISSING_ERR): + _aggregate_resets( + TensorDict({"_reset": False}, []), + reset_keys=["_reset", ("another", "_reset")], + ) + with pytest.raises(ValueError, match=PARTIAL_MISSING_ERR): + _aggregate_resets( + TensorDict({"_reset": False}, []), + reset_keys=[("another", "_reset"), "_reset"], + ) + + +@pytest.mark.parametrize( + "collector_class", + 
[MultiSyncDataCollector, MultiaSyncDataCollector, SyncDataCollector], +) +def test_collector_reloading(collector_class): + def make_env(): + return ContinuousActionVecMockEnv() + + dummy_env = make_env() + obs_spec = dummy_env.observation_spec["observation"] + policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) + policy = Actor(policy_module, spec=dummy_env.action_spec) + policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) + + collector_kwargs = { + "create_env_fn": make_env, + "policy": policy_explore, + "frames_per_batch": 30, + "total_frames": 90, + } + if collector_class is not SyncDataCollector: + collector_kwargs["create_env_fn"] = [ + collector_kwargs["create_env_fn"] for _ in range(3) + ] + + collector = collector_class(**collector_kwargs) + for i, _ in enumerate(collector): + if i == 3: + break + collector_frames = collector._frames + collector_iter = collector._iter + collector_state_dict = collector.state_dict() + collector.shutdown() + + collector = collector_class(**collector_kwargs) + collector.load_state_dict(collector_state_dict) + assert collector._frames == collector_frames + assert collector._iter == collector_iter + for _ in enumerate(collector): + raise AssertionError + collector.shutdown() + del collector + + +@pytest.mark.skipif( + IS_OSX, reason="setting different threads across workeres can randomly fail on OSX." +) +def test_num_threads(): + from torchrl.collectors import collectors + + _main_async_collector_saved = collectors._main_async_collector + collectors._main_async_collector = decorate_thread_sub_func( + collectors._main_async_collector, num_threads=3 + ) + num_threads = torch.get_num_threads() + try: + env = ContinuousActionVecMockEnv() + c = MultiSyncDataCollector( + [env], + policy=RandomPolicy(env.action_spec), + num_threads=7, + num_sub_threads=3, + total_frames=200, + frames_per_batch=200, + ) + assert torch.get_num_threads() == 7 + for _ in c: + pass + finally: + try: + c.shutdown() + del c + except Exception: + print("Failed to shut down collector") + # reset vals + collectors._main_async_collector = _main_async_collector_saved + torch.set_num_threads(num_threads) if __name__ == "__main__": diff --git a/test/test_cost.py b/test/test_cost.py index 9849fb09801..5c1a7dbc41c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -13,6 +13,8 @@ from dataclasses import asdict, dataclass from packaging import version as pack_version +from tensordict._tensordict import unravel_keys + from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule, @@ -20,6 +22,7 @@ ProbabilisticTensorDictSequential, ProbabilisticTensorDictSequential as ProbSeq, TensorDictModule as Mod, + TensorDictSequential, TensorDictSequential as Seq, ) @@ -45,6 +48,7 @@ ) from mocking_classes import ContinuousActionConvMockEnv from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule +from tensordict.nn.utils import Buffer # from torchrl.data.postprocs.utils import expand_as_right from tensordict.tensordict import assert_allclose_td, TensorDict @@ -68,7 +72,11 @@ SafeSequential, WorldModelWrapper, ) -from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal +from torchrl.modules.distributions.continuous import ( + NormalParamWrapper, + TanhDelta, + TanhNormal, +) from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -86,7 +94,6 @@ QValueModule, ValueOperator, ) -from torchrl.modules.utils import Buffer from torchrl.objectives import ( A2CLoss, ClipPPOLoss, @@ -98,8 
+105,10 @@ DreamerActorLoss, DreamerModelLoss, DreamerValueLoss, + DTLoss, IQLLoss, KLPENPPOLoss, + OnlineDTLoss, PPOLoss, QMixerLoss, SACLoss, @@ -163,6 +172,12 @@ def get_devices(): class LossModuleTestBase: + def _flatten_in_keys(self, in_keys): + return [ + in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key))) + for in_key in in_keys + ] + def tensordict_keys_test(self, loss_fn, default_keys, td_est=None): self.tensordict_keys_unknown_key_test(loss_fn) self.tensordict_keys_default_values_test(loss_fn, default_keys) @@ -336,6 +351,7 @@ def _create_mock_data_dqn( action = torch.argmax(action, -1, keepdim=False) reward = torch.randn(batch, 1) done = torch.zeros(batch, 1, dtype=torch.bool) + terminated = torch.zeros(batch, 1, dtype=torch.bool) td = TensorDict( batch_size=(batch,), source={ @@ -343,6 +359,7 @@ def _create_mock_data_dqn( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, action_key: action, @@ -380,6 +397,7 @@ def _create_seq_mock_data_dqn( # action_value = action_value.unsqueeze(-1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) if action_spec_type == "categorical": action_value = torch.max(action_value, -1, keepdim=True)[0] @@ -394,6 +412,7 @@ def _create_seq_mock_data_dqn( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -437,7 +456,8 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_value2 = loss_fn.target_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -450,6 +470,19 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_dqn_state_dict(self, delay_value, device, action_spec_type): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) + sd = loss_fn.state_dict() + loss_fn2 = DQNLoss(actor, loss_function="l2", delay_value=delay_value) + loss_fn2.load_state_dict(sd) + @pytest.mark.parametrize("n", range(4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_default_devices()) @@ -494,7 +527,8 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_value2 = loss_fn.target_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -525,6 +559,7 @@ def 
test_dqn_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -535,6 +570,7 @@ def test_dqn_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -621,7 +657,8 @@ def test_distributional_dqn( # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_value2 = loss_fn.target_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -640,7 +677,10 @@ def test_distributional_dqn( @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_dqn_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_dqn_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): n_obs = 3 n_action = 4 action_spec = OneHotDiscreteTensorSpec(n_action) @@ -652,18 +692,20 @@ def test_dqn_notensordict(self, observation_key, reward_key, done_key): in_keys=[observation_key], ) dqn_loss = DQNLoss(actor) - dqn_loss.set_keys(reward=reward_key, done=done_key) + dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) next_observation = torch.randn(n_obs) action = action_spec.rand() next_reward = torch.randn(1) next_done = torch.zeros(1, dtype=torch.bool) + next_terminated = torch.zeros(1, dtype=torch.bool) kwargs = { observation_key: observation, f"next_{observation_key}": next_observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, "action": action, } td = TensorDict(kwargs, []).unflatten_keys("_") @@ -688,6 +730,7 @@ def test_distributional_dqn_tensordict_keys(self): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", "steps_to_next_obs": "steps_to_next_obs", } @@ -820,6 +863,7 @@ def _create_mock_data_dqn( reward = torch.randn(*batch, 1, device=device) done = torch.zeros(*batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(*batch, 1, dtype=torch.bool, device=device) td = TensorDict( { "agents": TensorDict( @@ -841,6 +885,7 @@ def _create_mock_data_dqn( "state": next_state, "reward": reward, "done": done, + "terminated": terminated, }, batch_size=batch, device=device, @@ -889,7 +934,8 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + if p.requires_grad: + p.data += 3 target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -899,7 +945,8 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + if 
p.requires_grad: + p.data += 3 target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -909,9 +956,24 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_qmixer_state_dict(self, delay_value, device, action_spec_type): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + mixer = self._create_mock_mixer(device=device) + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + sd = loss_fn.state_dict() + loss_fn2 = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + loss_fn2.load_state_dict(sd) + @pytest.mark.parametrize("n", range(4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_default_devices()) @@ -956,7 +1018,8 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + if p.requires_grad: + p.data += 3 target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -966,7 +1029,8 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + if p.requires_grad: + p.data += 3 target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -976,7 +1040,8 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) @pytest.mark.parametrize( @@ -999,6 +1064,7 @@ def test_qmix_tensordict_keys(self, td_est): "action": ("agents", "action"), "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -1009,6 +1075,7 @@ def test_qmix_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1102,6 +1169,7 @@ def test_mixer_keys( "state": torch.zeros(32, 64, 64, 3), "reward": torch.zeros(32, 1), "done": torch.zeros(32, 1, dtype=torch.bool), + "terminated": torch.zeros(32, 1, dtype=torch.bool), }, [32], ), @@ -1155,20 +1223,20 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): return actor.to(device) def 
_create_mock_value( - self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None + self, batch=2, obs_dim=3, action_dim=4, state_dim=8, device="cpu", out_keys=None ): # Actor class ValueClass(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(obs_dim + action_dim, 1) + self.linear = nn.Linear(obs_dim + action_dim + state_dim, 1) - def forward(self, obs, act): - return self.linear(torch.cat([obs, act], -1)) + def forward(self, obs, state, act): + return self.linear(torch.cat([obs, state, act], -1)) module = ValueClass() value = ValueOperator( - module=module, in_keys=["observation", "action"], out_keys=out_keys + module=module, in_keys=["observation", "state", "action"], out_keys=out_keys ) return value.to(device) @@ -1204,10 +1272,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1227,10 +1297,12 @@ def _create_mock_data_ddpg( batch=8, obs_dim=3, action_dim=4, + state_dim=8, atoms=None, device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1240,14 +1312,19 @@ def _create_mock_data_ddpg( else: action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) + state = torch.randn(batch, state_dim, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ "observation": obs, + "state": state, "next": { "observation": next_obs, + "state": state, done_key: done, + terminated_key: terminated, reward_key: reward, }, "action": action, @@ -1262,15 +1339,20 @@ def _create_seq_mock_data_ddpg( T=4, obs_dim=3, action_dim=4, + state_dim=8, atoms=None, device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + total_state = torch.randn(batch, T + 1, state_dim, device=device) obs = total_obs[:, :T] next_obs = total_obs[:, 1:] + state = total_state[:, :T] + next_state = total_state[:, 1:] if atoms: action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( -1, 1 @@ -1278,15 +1360,20 @@ def _create_seq_mock_data_ddpg( else: action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "state": state.masked_fill_(~mask.unsqueeze(-1), 0.0), "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "state": next_state.masked_fill_(~mask.unsqueeze(-1), 0.0), done_key: done, + terminated_key: terminated, reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -1372,7 +1459,8 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): target_value = [p.clone() for p in 
loss_fn.target_value_network_params.values()] _i = -1 for _i, p in enumerate(loss_fn.parameters()): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert _i >= 0 target_actor2 = [ p.clone() for p in loss_fn.target_actor_network_params.values() @@ -1396,9 +1484,33 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)]) + def test_ddpg_state_dict(self, delay_actor, delay_value, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + loss_fn = DDPGLoss( + actor, + value, + loss_function="l2", + delay_actor=delay_actor, + delay_value=delay_value, + ) + state_dict = loss_fn.state_dict() + loss_fn2 = DDPGLoss( + actor, + value, + loss_function="l2", + delay_actor=delay_actor, + delay_value=delay_value, + ) + loss_fn2.load_state_dict(state_dict) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [False, True]) def test_ddpg_separate_losses( @@ -1569,6 +1681,7 @@ def test_ddpg_tensordict_keys(self, td_est): default_keys = { "reward": "reward", "done": "done", + "terminated": "terminated", "state_action_value": "state_action_value", "priority": "td_error", } @@ -1589,6 +1702,7 @@ def test_ddpg_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", "reward2"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1604,12 +1718,15 @@ def test_ddpg_tensordict_run(self, td_est): "priority": "td_error_test", "reward": "reward_test", "done": ("done", "test"), + "terminated": ("terminated", "test"), } actor = self._create_mock_actor() value = self._create_mock_value(out_keys=[tensor_keys["state_action_value"]]) td = self._create_mock_data_ddpg( - reward_key="reward_test", done_key=("done", "test") + reward_key="reward_test", + done_key=("done", "test"), + terminated_key=("terminated", "test"), ) loss_fn = DDPGLoss( actor, @@ -1637,15 +1754,18 @@ def test_ddpg_notensordict(self): "observation": td.get("observation"), "next_reward": td.get(("next", "reward")), "next_done": td.get(("next", "done")), + "next_terminated": td.get(("next", "terminated")), "next_observation": td.get(("next", "observation")), "action": td.get("action"), + "state": td.get("state"), + "next_state": td.get(("next", "state")), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") with pytest.warns(UserWarning, match="No target network updater has been"): loss_val_td = loss(td) loss_val = loss(**kwargs) - for i, key in enumerate(loss_val_td.keys()): + for i, key in enumerate(loss.out_keys): torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) # test select loss.select_out_keys("loss_actor", "target_value") @@ -1746,10 +1866,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, 
dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1783,6 +1905,7 @@ def _create_mock_data_td3( observation_key="observation", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1793,6 +1916,7 @@ def _create_mock_data_td3( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -1800,6 +1924,7 @@ def _create_mock_data_td3( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -1823,6 +1948,7 @@ def _create_seq_mock_data_td3( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -1832,6 +1958,7 @@ def _create_seq_mock_data_td3( "observation": next_obs * mask.to(obs.dtype), "reward": reward * mask.to(obs.dtype), "done": done, + "terminated": terminated, }, "collector": {"mask": mask}, "action": action * mask.to(obs.dtype), @@ -1941,7 +2068,65 @@ def test_td3( assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1]) + @pytest.mark.parametrize("noise_clip", [0.1]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_td3_state_dict( + self, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + use_action_spec, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None + else: + bounds = (-1, 1) + action_spec = None + loss_fn = TD3Loss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + sd = loss_fn.state_dict() + loss_fn2 = TD3Loss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + loss_fn2.load_state_dict(sd) @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") @pytest.mark.parametrize("device", get_default_devices()) @@ -1957,7 +2142,7 @@ def 
test_td3_separate_losses( loss_fn = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), minimum=-1, maximum=1), + action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), loss_function="l2", separate_losses=separate_losses, ) @@ -2076,8 +2261,16 @@ def test_td3_batcher( sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_actor = loss_fn.target_actor_network_params.clone().values( @@ -2087,7 +2280,8 @@ def test_td3_batcher( include_nested=True, leaves_only=True ) for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_actor2 = loss_fn.target_actor_network_params.clone().values( include_nested=True, leaves_only=True ) @@ -2115,7 +2309,8 @@ def test_td3_batcher( assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) @pytest.mark.parametrize( @@ -2136,6 +2331,7 @@ def test_td3_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -2154,6 +2350,7 @@ def test_td3_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -2187,22 +2384,29 @@ def test_constructor(self, spec, bounds): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_td3_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(in_keys=[observation_key]) qvalue = self._create_mock_value( observation_key=observation_key, out_keys=["state_action_value"] ) td = self._create_mock_data_td3( - observation_key=observation_key, reward_key=reward_key, done_key=done_key + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, ) loss = TD3Loss(actor, qvalue, action_spec=actor.spec) - loss.set_keys(reward=reward_key, done=done_key) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) kwargs = { observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), "action": td.get("action"), } @@ -2213,8 +2417,12 @@ def 
test_td3_notensordict(self, observation_key, reward_key, done_key): loss_val_td = loss(td) torch.manual_seed(0) loss_val = loss(**kwargs) - for i, key in enumerate(loss_val_td.keys()): + for i in loss_val: + assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" + + for i, key in enumerate(loss.out_keys): torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + # test select loss.select_out_keys("loss_actor", "loss_qvalue") torch.manual_seed(0) @@ -2335,10 +2543,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -2375,6 +2585,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -2386,6 +2597,7 @@ def _create_mock_data_sac( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -2393,6 +2605,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -2416,6 +2629,7 @@ def _create_seq_mock_data_sac( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -2424,6 +2638,7 @@ def _create_seq_mock_data_sac( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -2589,7 +2804,67 @@ def test_sac( assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_value", (True, False)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_sac_state_dict( + self, + delay_value, + delay_actor, + delay_qvalue, + num_qvalue, + device, + version, + ): + if (delay_actor or delay_qvalue) and not delay_value: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + + kwargs = {} + if delay_actor: + 
kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + if delay_value: + kwargs["delay_value"] = True + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [False, True]) @@ -2758,7 +3033,14 @@ def test_sac_batcher( sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_actor = [ @@ -2781,7 +3063,8 @@ def test_sac_batcher( ) ] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_actor2 = [ p.clone() for p in loss_fn.target_actor_network_params.values( @@ -2830,7 +3113,8 @@ def test_sac_batcher( # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all( (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) ) @@ -2864,6 +3148,7 @@ def test_sac_tensordict_keys(self, td_est, version): "log_prob": "_log_prob", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -2883,6 +3168,7 @@ def test_sac_tensordict_keys(self, td_est, version): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -2890,8 +3176,9 @@ def test_sac_tensordict_keys(self, td_est, version): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_sac_notensordict( - self, action_key, observation_key, reward_key, done_key, version + self, action_key, observation_key, reward_key, done_key, terminated_key, version ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -2899,6 +3186,7 @@ def test_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -2919,13 +3207,19 @@ def test_sac_notensordict( qvalue_network=qvalue, value_network=value, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": 
td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -2966,6 +3260,49 @@ def test_sac_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] + def test_state_dict(self, version): + if version == 1: + pytest.skip("Test not implemented for version 1.") + model = torch.nn.Linear(3, 4) + actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) + policy = ProbabilisticActor( + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=TanhDelta, + ) + value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + state = loss.state_dict() + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + # with an access in between + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.target_entropy + state = loss.state_dict() + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -3013,7 +3350,9 @@ def forward(self, obs): return self.linear(obs) module = ValueClass() - qvalue = ValueOperator(module=module, in_keys=[observation_key]) + qvalue = ValueOperator( + module=module, in_keys=[observation_key], out_keys=["action_value"] + ) return qvalue.to(device) def _create_mock_distributional_actor( @@ -3031,6 +3370,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -3048,6 +3388,7 @@ def _create_mock_data_sac( action = (action_value == action_value.max(-1, True)[0]).to(torch.long) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3055,6 +3396,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3083,6 +3425,7 @@ def _create_seq_mock_data_sac( reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3091,6 +3434,7 @@ def _create_seq_mock_data_sac( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3205,7 +3549,59 @@ def test_discrete_sac( assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: 
{p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("target_entropy_weight", [0.5]) + @pytest.mark.parametrize("target_entropy", ["auto"]) + def test_discrete_sac_state_dict( + self, + delay_qvalue, + num_qvalue, + device, + target_entropy_weight, + target_entropy, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + kwargs = {} + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + loss_function="l2", + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + loss_function="l2", + **kwargs, + ) + loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("n", list(range(4))) @pytest.mark.parametrize("delay_qvalue", (True, False)) @@ -3273,7 +3669,14 @@ def test_discrete_sac_batcher( sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_actor = [ @@ -3289,7 +3692,8 @@ def test_discrete_sac_batcher( ) ] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_actor2 = [ p.clone() for p in loss_fn.target_actor_network_params.values( @@ -3320,7 +3724,8 @@ def test_discrete_sac_batcher( # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) @pytest.mark.parametrize( @@ -3343,6 +3748,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -3362,6 +3768,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -3369,8 +3776,9 @@ def test_discrete_sac_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", 
["terminated", "terminated2"]) def test_discrete_sac_notensordict( - self, action_key, observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3378,6 +3786,7 @@ def test_discrete_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -3392,13 +3801,19 @@ def test_discrete_sac_notensordict( qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -3412,14 +3827,6 @@ def test_discrete_sac_notensordict( torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) - torch.testing.assert_close( - loss_val_td.get("state_action_value_actor"), loss_val[5] - ) - torch.testing.assert_close( - loss_val_td.get("action_log_prob_actor"), loss_val[6] - ) - torch.testing.assert_close(loss_val_td.get("next.state_value"), loss_val[7]) - torch.testing.assert_close(loss_val_td.get("target_value"), loss_val[8]) # test select torch.manual_seed(self.seed) loss.select_out_keys("loss_actor", "loss_alpha") @@ -3520,10 +3927,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -3600,6 +4009,7 @@ def _create_mock_data_redq( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -3610,6 +4020,7 @@ def _create_mock_data_redq( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3617,6 +4028,7 @@ def _create_mock_data_redq( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3640,6 +4052,7 @@ def _create_seq_mock_data_redq( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3648,6 +4061,7 @@ def _create_seq_mock_data_redq( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), 
"done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3753,7 +4167,40 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_redq_state_dict(self, delay_qvalue, num_qvalue, device): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = REDQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + delay_qvalue=delay_qvalue, + ) + sd = loss_fn.state_dict() + loss_fn2 = REDQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + delay_qvalue=delay_qvalue, + ) + loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("separate_losses", [False, True]) def test_redq_separate_losses(self, separate_losses): @@ -4108,7 +4555,14 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_actor = loss_fn.target_actor_network_params.clone().values( @@ -4118,7 +4572,8 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): include_nested=True, leaves_only=True ) for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_actor2 = loss_fn.target_actor_network_params.clone().values( include_nested=True, leaves_only=True ) @@ -4148,7 +4603,8 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all( (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) ) @@ -4174,6 +4630,7 @@ def test_redq_tensordict_keys(self, td_est): "state_action_value": "state_action_value", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -4192,6 +4649,7 @@ def test_redq_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -4199,9 
+4657,10 @@ def test_redq_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize("deprec", [True, False]) def test_redq_notensordict( - self, action_key, observation_key, reward_key, done_key, deprec + self, action_key, observation_key, reward_key, done_key, terminated_key, deprec ): torch.manual_seed(self.seed) td = self._create_mock_data_redq( @@ -4209,6 +4668,7 @@ def test_redq_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -4229,13 +4689,19 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -4334,6 +4800,7 @@ def _create_mock_data_cql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -4341,6 +4808,7 @@ def _create_mock_data_cql( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, "action": action, @@ -4364,6 +4832,7 @@ def _create_seq_mock_data_cql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -4372,6 +4841,7 @@ def _create_seq_mock_data_cql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -4503,7 +4973,64 @@ def test_cql( assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("max_q_backup", [True]) + @pytest.mark.parametrize("deterministic_backup", [True]) + @pytest.mark.parametrize("with_lagrange", [True]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_cql_state_dict( + self, + delay_actor, + delay_qvalue, + 
max_q_backup, + deterministic_backup, + with_lagrange, + device, + ): + if delay_actor or delay_qvalue: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = CQLLoss( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = CQLLoss( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + **kwargs, + ) + loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("n", list(range(4))) @pytest.mark.parametrize("delay_actor", (True, False)) @@ -4573,7 +5100,14 @@ def test_cql_batcher( sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_actor = [ @@ -4589,7 +5123,8 @@ def test_cql_batcher( ) ] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_actor2 = [ p.clone() for p in loss_fn.target_actor_network_params.values( @@ -4622,7 +5157,8 @@ def test_cql_batcher( # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all( (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) ) @@ -4740,6 +5276,7 @@ def _create_mock_data_ppo( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", sample_log_prob_key="sample_log_prob", ): # create a tensordict @@ -4751,6 +5288,7 @@ def _create_mock_data_ppo( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -4758,6 +5296,7 @@ def _create_mock_data_ppo( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -4790,6 +5329,7 @@ def _create_seq_mock_data_ppo( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 @@ -4800,6 +5340,7 @@ def _create_seq_mock_data_ppo( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": 
reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -4863,8 +5404,8 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) assert counter == 2 value.zero_grad() @@ -4877,11 +5418,24 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) assert counter == 2 actor.zero_grad() + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("gradient_mode", (True,)) + @pytest.mark.parametrize("device", get_default_devices()) + def test_ppo_state_dict(self, loss_class, device, gradient_mode): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + loss_fn = loss_class(actor, value, loss_critic_type="l2") + sd = loss_fn.state_dict() + loss_fn2 = loss_class(actor, value, loss_critic_type="l2") + loss_fn2.load_state_dict(sd) + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @@ -4933,8 +5487,8 @@ def test_ppo_shared(self, loss_class, device, advantage): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) assert counter == 2 value.zero_grad() @@ -4947,8 +5501,8 @@ def test_ppo_shared(self, loss_class, device, advantage): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) actor.zero_grad() assert counter == 4 @@ -5120,6 +5674,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -5138,6 +5693,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "value": ("value", value_key), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -5217,8 +5773,8 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target" in name) + assert ("critic" not in name) or ("target" in name) assert counter == 2 value.zero_grad() @@ -5231,8 +5787,8 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target" in name) + assert ("critic" in name) or ("target" in name) assert counter == 2 actor.zero_grad() @@ -5242,6 +5798,7 @@ def 
test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_ppo_notensordict( self, loss_class, @@ -5250,6 +5807,7 @@ def test_ppo_notensordict( observation_key, reward_key, done_key, + terminated_key, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -5258,6 +5816,7 @@ def test_ppo_notensordict( sample_log_prob_key=sample_log_prob_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -5268,6 +5827,7 @@ def test_ppo_notensordict( action=action_key, reward=reward_key, done=done_key, + terminated=terminated_key, sample_log_prob=sample_log_prob_key, ) @@ -5277,6 +5837,7 @@ def test_ppo_notensordict( sample_log_prob_key: td.get(sample_log_prob_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -5379,10 +5940,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -5418,6 +5981,7 @@ def _create_seq_mock_data_a2c( observation_key="observation", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -5431,6 +5995,7 @@ def _create_seq_mock_data_a2c( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 @@ -5441,6 +6006,7 @@ def _create_seq_mock_data_a2c( "next": { observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), done_key: done, + terminated_key: terminated, reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -5511,8 +6077,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) value.zero_grad() loss_objective.backward() @@ -5522,13 +6088,24 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) actor.zero_grad() # test reset loss_fn.reset() + @pytest.mark.parametrize("gradient_mode", (True, False)) + 
@pytest.mark.parametrize("device", get_default_devices()) + def test_a2c_state_dict(self, device, gradient_mode): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + sd = loss_fn.state_dict() + loss_fn2 = A2CLoss(actor, value, loss_critic_type="l2") + loss_fn2.load_state_dict(sd) + @pytest.mark.parametrize("separate_losses", [False, True]) def test_a2c_separate_losses(self, separate_losses): torch.manual_seed(self.seed) @@ -5557,8 +6134,8 @@ def test_a2c_separate_losses(self, separate_losses): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) else: if p.grad is not None and p.grad.norm() > 0.0: assert ("actor" in name) or ("critic" in name) @@ -5573,8 +6150,8 @@ def test_a2c_separate_losses(self, separate_losses): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) actor.zero_grad() # test reset @@ -5628,8 +6205,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage): assert "actor" not in name assert "critic" in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) for param in params: param.grad = None @@ -5640,8 +6217,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) for param in params: param.grad = None @@ -5667,6 +6244,7 @@ def test_a2c_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -5685,6 +6263,7 @@ def test_a2c_tensordict_keys(self, td_est): "value": ("value", "value_state_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -5699,12 +6278,14 @@ def test_a2c_tensordict_keys_run(self, device): action_key = "action_test" reward_key = "reward_test" done_key = ("done", "test") + terminated_key = ("terminated", "test") td = self._create_seq_mock_data_a2c( device=device, action_key=action_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(device=device) @@ -5721,6 +6302,7 @@ def test_a2c_tensordict_keys_run(self, device): value=value_key, reward=reward_key, done=done_key, + terminated=terminated_key, ) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( @@ -5730,6 +6312,7 @@ def test_a2c_tensordict_keys_run(self, device): action=action_key, reward=reward_key, done=done_key, + terminated=done_key, ) advantage(td) @@ -5744,8 +6327,8 @@ def test_a2c_tensordict_keys_run(self, device): if p.grad is not None and p.grad.norm() > 0.0: assert "actor" not in name if p.grad is None: - assert "actor" in name - assert "critic" not in name + assert ("actor" in name) or ("target_" in name) + assert 
("critic" not in name) or ("target_" in name) value.zero_grad() loss_objective.backward() @@ -5755,8 +6338,8 @@ def test_a2c_tensordict_keys_run(self, device): assert "actor" in name assert "critic" not in name if p.grad is None: - assert "actor" not in name - assert "critic" in name + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) actor.zero_grad() # test reset @@ -5766,7 +6349,10 @@ def test_a2c_tensordict_keys_run(self, device): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_a2c_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(observation_key=observation_key) @@ -5776,16 +6362,23 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) loss = A2CLoss(actor, value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { observation_key: td.get(observation_key), f"next_{observation_key}": td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), action_key: td.get(action_key), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -5876,6 +6469,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est "observation": torch.randn(batch, n_obs), "reward": torch.randn(batch, 1), "done": torch.zeros(batch, 1, dtype=torch.bool), + "terminated": torch.zeros(batch, 1, dtype=torch.bool), }, "action": torch.randn(batch, n_act), }, @@ -5959,6 +6553,7 @@ def test_reinforce_tensordict_keys(self, td_est): "sample_log_prob": "sample_log_prob", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -5982,6 +6577,7 @@ def test_reinforce_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -6014,10 +6610,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -6114,8 +6712,9 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_reinforce_notensordict( - self, action_key, 
observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) n_obs = 3 @@ -6134,19 +6733,26 @@ def test_reinforce_notensordict( spec=UnboundedContinuousTensorSpec(n_act), ) loss = ReinforceLoss(actor=actor_net, critic=value_net) - loss.set_keys(reward=reward_key, done=done_key, action=action_key) + loss.set_keys( + reward=reward_key, + done=done_key, + action=action_key, + terminated=terminated_key, + ) observation = torch.randn(batch, n_obs) action = torch.randn(batch, n_act) next_reward = torch.randn(batch, 1) next_observation = torch.randn(batch, n_obs) next_done = torch.zeros(batch, 1, dtype=torch.bool) + next_terminated = torch.zeros(batch, 1, dtype=torch.bool) kwargs = { action_key: action, observation_key: observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, f"next_{observation_key}": next_observation, } td = TensorDict(kwargs, [batch]).unflatten_keys("_") @@ -6187,6 +6793,9 @@ def _create_world_model_data( ), "reward": torch.randn(batch_size, temporal_length, 1), "done": torch.zeros(batch_size, temporal_length, dtype=torch.bool), + "terminated": torch.zeros( + batch_size, temporal_length, dtype=torch.bool + ), }, "action": torch.randn(batch_size, temporal_length, 64), }, @@ -6611,6 +7220,7 @@ def test_dreamer_actor_tensordict_keys(self, td_est, device): "reward": "reward", "value": "state_value", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -6637,130 +7247,506 @@ def test_dreamer_value_tensordict_keys(self, device): self.tensordict_keys_test(loss_fn, default_keys=default_keys) -@pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" -) -class TestIQL(LossModuleTestBase): +class TestOnlineDT(LossModuleTestBase): seed = 0 - def _create_mock_actor( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - observation_key="observation", - ): + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=["observation"], out_keys=["loc", "scale"] ) actor = ProbabilisticActor( module=module, + distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, - distribution_class=TanhNormal, ) return actor.to(device) - def _create_mock_qvalue( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - out_keys=None, - observation_key="observation", - action_key="action", - ): - class ValueClass(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(obs_dim + action_dim, 1) - - def forward(self, obs, act): - return self.linear(torch.cat([obs, act], -1)) - - module = ValueClass() - qvalue = ValueOperator( - module=module, - in_keys=[observation_key, action_key], - out_keys=out_keys, + def _create_mock_data_odt(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, 1, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + "reward2go": reward2go, + }, + device=device, ) - return qvalue.to(device) + 
return td - def _create_mock_value( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - out_keys=None, - observation_key="observation", + def _create_seq_mock_data_odt( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" ): - module = nn.Linear(obs_dim, 1) - value = ValueOperator( - module=module, in_keys=[observation_key], out_keys=out_keys - ) - return value.to(device) + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, T, 1, device=device) - def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 - ): - common_net = MLP( - num_cells=ncells, - in_features=n_obs, - depth=3, - out_features=n_hidden, - ) - actor_net = MLP( - num_cells=ncells, - in_features=n_hidden, - depth=1, - out_features=2 * n_act, - ) - value_net = MLP( - in_features=n_hidden, - num_cells=ncells, - depth=1, - out_features=1, - ) - qvalue_net = MLP( - in_features=n_hidden + n_act, - num_cells=ncells, - depth=1, - out_features=1, - ) - batch = [batch, T] td = TensorDict( - { - "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), - "done": torch.zeros(*batch, 1, dtype=torch.bool), - "next": { - "obs": torch.randn(*batch, n_obs), - "reward": torch.randn(*batch, 1), - "done": torch.zeros(*batch, 1, dtype=torch.bool), - }, + batch_size=(batch, T), + source={ + "observation": obs, + "reward": reward2go, + "action": action, }, - batch, - names=[None, "time"], + device=device, ) - common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) - actor = ProbSeq( - common, - Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), - ProbMod( - in_keys=["loc", "scale"], - out_keys=["action"], - distribution_class=TanhNormal, - ), + return td + + @pytest.mark.parametrize("device", get_available_devices()) + def test_odt(self, device): + torch.manual_seed(self.seed) + td = self._create_mock_data_odt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = OnlineDTLoss(actor) + loss = loss_fn(td) + loss_transformer = sum( + loss[key] + for key in loss.keys() + if key.startswith("loss") and key != "loss_alpha" + ) + loss_alpha = loss["loss_alpha"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + loss_alpha.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "alpha" in name + if p.grad is None: + assert "actor" in name + assert "alpha" not in name + loss_fn.zero_grad() + + sum([loss_transformer, loss_alpha]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def 
test_odt_state_dict(self, device): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + + loss_fn = OnlineDTLoss(actor) + sd = loss_fn.state_dict() + loss_fn2 = OnlineDTLoss(actor) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_seq_odt(self, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_odt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = OnlineDTLoss(actor) + loss = loss_fn(td) + loss_transformer = sum( + loss[key] + for key in loss.keys() + if key.startswith("loss") and key != "loss_alpha" + ) + loss_alpha = loss["loss_alpha"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + loss_alpha.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "alpha" in name + if p.grad is None: + assert "actor" in name + assert "alpha" not in name + loss_fn.zero_grad() + + sum([loss_transformer, loss_alpha]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + def test_onlinedt_tensordict_keys(self): + actor = self._create_mock_actor() + loss_fn = OnlineDTLoss(actor) + + default_keys = { + "action": "action", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + def test_onlinedt_notensordict(self, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_odt(device=device) + loss_fn = OnlineDTLoss(actor) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(td.flatten_keys("_").select(*in_keys)) + + torch.manual_seed(0) + loss_val_td = loss_fn(td) + torch.manual_seed(0) + loss_log_likelihood, loss_entropy, loss_alpha, alpha, entropy = loss_fn( + **kwargs + ) + torch.testing.assert_close( + loss_val_td.get("loss_log_likelihood"), loss_log_likelihood + ) + torch.testing.assert_close(loss_val_td.get("loss_entropy"), loss_entropy) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_alpha) + # test select + torch.manual_seed(0) + loss_fn.select_out_keys("loss_entropy") + if torch.__version__ >= "2.0.0": + loss_entropy = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_entropy = loss_fn(**kwargs) + return + assert loss_entropy == loss_val_td["loss_entropy"] + + +class TestDT(LossModuleTestBase): + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"]) + actor = 
ProbabilisticActor( + module=module, + distribution_class=TanhDelta, + in_keys=["param"], + spec=action_spec, + ) + return actor.to(device) + + def _create_mock_data_dt(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, 1, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_dt( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs, + "action": action, + }, + device=device, + ) + return td + + def test_dt_tensordict_keys(self): + actor = self._create_mock_actor() + loss_fn = DTLoss(actor) + + default_keys = { + "action": "action", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + def test_dt_notensordict(self, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_dt(device=device) + loss_fn = DTLoss(actor) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(td.flatten_keys("_").select(*in_keys)) + + loss_val_td = loss_fn(td) + loss_val = loss_fn(**kwargs) + torch.testing.assert_close(loss_val_td.get("loss"), loss_val) + # test select + loss_fn.select_out_keys("loss") + if torch.__version__ >= "2.0.0": + loss_actor = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor = loss_fn(**kwargs) + return + assert loss_actor == loss_val_td["loss"] + + @pytest.mark.parametrize("device", get_available_devices()) + def test_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def test_dt_state_dict(self, device): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + sd = loss_fn.state_dict() + loss_fn2 = DTLoss(actor) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_seq_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) 
+ + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestIQL(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + observation_key="observation", + action_key="action", + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return qvalue.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + observation_key="observation", + ): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, in_keys=[observation_key], out_keys=out_keys + ) + return value.to(device) + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 + ): + common_net = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value_net = MLP( + in_features=n_hidden, + num_cells=ncells, + depth=1, + out_features=1, + ) + qvalue_net = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch, T] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "sample_log_prob": torch.randn(*batch), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + names=[None, "time"], + ) + common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + 
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + ), ) value = Seq( common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]) @@ -6792,6 +7778,7 @@ def _create_mock_data_iql( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -6803,6 +7790,7 @@ def _create_mock_data_iql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -6810,6 +7798,7 @@ def _create_mock_data_iql( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -6833,6 +7822,7 @@ def _create_seq_mock_data_iql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -6841,6 +7831,7 @@ def _create_seq_mock_data_iql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -6966,8 +7957,53 @@ def test_iql( assert len({p for n, p in named_parameters}) == len(list(named_parameters)) assert len({p for n, p in named_buffers}) == len(list(named_buffers)) - for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("temperature", [0.0]) + @pytest.mark.parametrize("expectile", [0.1]) + def test_iql_state_dict( + self, + num_qvalue, + device, + temperature, + expectile, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + loss_fn = IQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + sd = loss_fn.state_dict() + loss_fn2 = IQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("separate_losses", [False, True]) def test_iql_separate_losses(self, separate_losses): @@ -7189,7 +8225,14 @@ def test_iql_batcher( sum([item for _, item in loss_ms.items()]).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: - assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + if 
not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" # Check param update effect on targets target_qvalue = [ @@ -7199,7 +8242,8 @@ def test_iql_batcher( ) ] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) target_qvalue2 = [ p.clone() for p in loss_fn.target_qvalue_network_params.values( @@ -7218,7 +8262,8 @@ def test_iql_batcher( # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) @pytest.mark.parametrize( @@ -7244,6 +8289,7 @@ def test_iql_tensordict_keys(self, td_est): "value": "state_value", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -7263,6 +8309,7 @@ def test_iql_tensordict_keys(self, td_est): key_mapping = { "value": ("value", "value_test"), "done": ("done", "done_test"), + "terminated": ("terminated", "terminated_test"), "reward": ("reward", ("reward", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -7271,13 +8318,17 @@ def test_iql_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_iql_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_iql_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) td = self._create_mock_data_iql( action_key=action_key, observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -7289,13 +8340,19 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke value = self._create_mock_value(observation_key=observation_key) loss = IQLLoss(actor_network=actor, qvalue_network=qvalue, value_network=value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -7328,7 +8385,10 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke @pytest.mark.parametrize("create_target_params", [True, False]) -def test_param_buffer_types(create_target_params): +@pytest.mark.parametrize( + "cast", [None, torch.float, torch.double, *get_default_devices()] +) +def test_param_buffer_types(create_target_params, cast): class MyLoss(LossModule): def __init__(self, actor_network): super().__init__() @@ -7347,16 +8407,25 @@ def 
_forward_value_estimator_keys(self, **kwargs) -> None: out_keys=["action"], ) loss = MyLoss(actor_module) - assert isinstance(loss.actor_network_params["module", "0", "weight"], nn.Parameter) - assert isinstance( - loss.target_actor_network_params["module", "0", "weight"], nn.Parameter - ) - assert loss.actor_network_params["module", "0", "weight"].requires_grad - assert not loss.target_actor_network_params["module", "0", "weight"].requires_grad - assert isinstance(loss.actor_network_params["module", "0", "bias"], nn.Parameter) - assert isinstance( - loss.target_actor_network_params["module", "0", "bias"], nn.Parameter - ) + if cast is not None: + loss.to(cast) + for name in ("weight", "bias"): + param = loss.actor_network_params["module", "0", name] + assert isinstance(param, nn.Parameter) + target = loss.target_actor_network_params["module", "0", name] + if create_target_params: + assert target.data_ptr() != param.data_ptr() + else: + assert target.data_ptr() == param.data_ptr() + assert param.requires_grad + assert not target.requires_grad + if cast is not None: + if isinstance(cast, torch.dtype): + assert param.dtype == cast + assert target.dtype == cast + else: + assert param.device == cast + assert target.device == cast if create_target_params: assert ( @@ -7451,7 +8520,7 @@ def __init__(self): module = custom_module_error().to(device) with pytest.raises( - RuntimeError, match="Your module seems to have a target tensor list " + ValueError, match="The loss_module must be a LossModule instance" ): if mode == "hard": upd = HardUpdate( @@ -7482,7 +8551,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: pass module = custom_module(delay_module=False) - with pytest.raises(RuntimeError, match="The target and source data are identical"): + with pytest.raises( + RuntimeError, + match="Did not find any target parameters or buffers in the loss module", + ): if mode == "hard": upd = HardUpdate( module, value_network_update_interval=value_network_update_interval @@ -7495,8 +8567,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: else: raise NotImplementedError - with pytest.warns(UserWarning, match="No target network updater has been"): - module = custom_module().to(device).to(dtype) + # this is now allowed + # with pytest.warns(UserWarning, match="No target network updater has been"): + # module = custom_module().to(device).to(dtype) if mode == "soft": with pytest.raises(ValueError, match="One and only one argument"): @@ -7506,6 +8579,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: tau=0.1, ) + module = custom_module(delay_module=True) _ = module.module1_params with pytest.warns(UserWarning, match="No target network updater has been"): _ = module.target_module1_params @@ -7596,24 +8670,77 @@ class TestValues: @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) # @pytest.mark.parametrize("random_gamma,rolling_gamma", [[True, False], [True, True], [False, None]]) @pytest.mark.parametrize("random_gamma,rolling_gamma", [[False, None]]) def test_tdlambda(self, device, gamma, lmbda, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = 
torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) - next_state_value = torch.randn(*N, T, 1, device=device) if random_gamma: gamma = torch.rand_like(reward) * gamma + next_state_value = torch.cat( + [state_value[..., 1:, :], torch.randn_like(state_value[..., -1:, :])], -2 + ) + r1 = vec_td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r2 = td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r3, *_ = vec_generalized_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + torch.testing.assert_close(r3, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) + + # test when v' is not v from next step (not working with gae) + next_state_value = torch.randn_like(next_state_value) r1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -7630,7 +8757,9 @@ def test_tdlambda_multi( torch.manual_seed(0) D = feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -7643,8 +8772,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td_lambda_advantage_estimate( @@ -7653,8 +8783,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -7666,8 +8797,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7683,8 +8815,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7701,8 +8834,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, 
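# The hunks above switch the functional advantage estimators from a positional
# `done` argument to explicit done= / terminated= keywords: `terminated` marks
# true episode ends (no bootstrapping), `done` is a superset that also covers
# truncation. Minimal hedged sketch of that calling convention; the import path
# (torchrl.objectives.value.functional) and the shapes are assumptions for
# illustration, not part of this patch.
import torch
from torchrl.objectives.value.functional import (
    td_lambda_advantage_estimate,
    vec_td_lambda_advantage_estimate,
)

B, T = 4, 10
reward = torch.randn(B, T, 1)
state_value = torch.randn(B, T, 1)
next_state_value = torch.randn(B, T, 1)
# terminated flags real episode ends; done additionally includes truncations
terminated = torch.zeros(B, T, 1, dtype=torch.bool).bernoulli_(0.1)
done = terminated | torch.zeros_like(terminated).bernoulli_(0.1)

adv = td_lambda_advantage_estimate(
    0.99, 0.95, state_value, next_state_value, reward,
    done=done, terminated=terminated,
)
adv_vec = vec_td_lambda_advantage_estimate(
    0.99, 0.95, state_value, next_state_value, reward,
    done=done, terminated=terminated,
)
torch.testing.assert_close(adv, adv_vec, rtol=1e-4, atol=1e-4)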
time_dim=-2, ) for i in range(D[0]) @@ -7717,8 +8851,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7738,7 +8873,9 @@ def test_tdlambda_multi( def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -7746,10 +8883,22 @@ def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): gamma = torch.rand_like(reward) * gamma r1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -7766,7 +8915,9 @@ def test_td1_multi( D = feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -7778,8 +8929,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td1_advantage_estimate( @@ -7787,8 +8939,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -7799,8 +8952,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7815,8 +8969,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7832,8 +8987,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7847,8 +9003,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + 
done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -7866,22 +9023,36 @@ def test_td1_multi( @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) - @pytest.mark.parametrize("has_done", [True, False]) + @pytest.mark.parametrize("has_done", [False, True]) def test_gae(self, device, gamma, lmbda, N, T, dtype, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) r1 = vec_generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -7906,8 +9077,10 @@ def test_gae_param_as_tensor( T = 200 done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) @@ -7927,10 +9100,22 @@ def test_gae_param_as_tensor( lmbda_vec = lmbda r1 = vec_generalized_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -7951,8 +9136,10 @@ def test_gae_multidim( torch.manual_seed(0) done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device, dtype=dtype) state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) @@ -7963,7 +9150,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) r2 = generalized_advantage_estimate( @@ -7972,7 +9160,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) if len(D) == 2: @@ -7983,7 +9172,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -7996,7 
+9186,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -8010,7 +9201,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -8022,7 +9214,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -8043,7 +9236,7 @@ def test_gae_multidim( @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("has_done", [True, False]) def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): """Tests vec_td_lambda_advantage_estimate against itself with @@ -8053,32 +9246,61 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - + # if len(N) == 2: + # print(terminated[4, 0, -10:]) + # print(done[4, 0, -10:]) v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8104,8 +9326,10 @@ def test_tdlambda_tensor_gamma_single_element( torch.manual_seed(0) done = torch.zeros(*N, T, F, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, F, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, F, device=device) state_value = torch.randn(*N, T, F, device=device) next_state_value = torch.randn(*N, T, 
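# Conceptually, the GAE tests above exercise two distinct roles for the flags:
# `terminated` removes the bootstrap term gamma * V(s') from the TD error,
# while `done` resets the lambda-accumulator at trajectory boundaries
# (including truncations). The loop below is an editorial reference sketch of
# that semantics only; it is not torchrl's implementation.
import torch

def gae_reference(gamma, lmbda, value, next_value, reward, done, terminated):
    # all tensors shaped [*batch, T, 1]; done/terminated are boolean
    not_term = (~terminated).float()
    not_done = (~done).float()
    T = reward.shape[-2]
    adv = torch.zeros_like(reward)
    running = torch.zeros_like(reward[..., 0, :])
    for t in reversed(range(T)):
        delta = (
            reward[..., t, :]
            + gamma * next_value[..., t, :] * not_term[..., t, :]
            - value[..., t, :]
        )
        running = delta + gamma * lmbda * not_done[..., t, :] * running
        adv[..., t, :] = running
    return adv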
F, device=device) @@ -8123,22 +9347,47 @@ def test_tdlambda_tensor_gamma_single_element( lmbda_vec = lmbda v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8156,8 +9405,10 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8165,23 +9416,44 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8203,8 +9475,10 @@ def test_vectdlambda_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8212,23 +9486,48 @@ def test_vectdlambda_tensor_gamma( gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = td_lambda_advantage_estimate( - 
gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8249,28 +9548,55 @@ def test_vectd1_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8292,8 +9618,10 @@ def test_vectdlambda_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8307,8 +9635,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, :]].all(): # if a not-done follows a done, then rolling_gamma=False cannot be used @@ 
-8321,8 +9650,24 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match=r"The vectorized version of TD" + ): + vec_td_lambda_advantage_estimate( + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -8331,8 +9676,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8352,8 +9698,10 @@ def test_vectd1_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8366,10 +9714,14 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) - if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, :]].all(): + if ( + rolling_gamma is False + and not terminated[..., 1:, :][terminated[..., :-1, :]].all() + ): # if a not-done follows a done, then rolling_gamma=False cannot be used with pytest.raises( NotImplementedError, match="When using rolling_gamma=False" @@ -8379,8 +9731,23 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match="The vectorized version of TD" + ): + vec_td1_advantage_estimate( + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td1_advantage_estimate( @@ -8388,8 +9755,9 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8441,8 +9809,10 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -8457,8 +9827,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v1a = td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -8466,8 +9837,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + 
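# The new branches above expect the *vectorized* TD estimators to raise
# NotImplementedError ("The vectorized version of TD ...") when
# rolling_gamma=False is combined with a tensor-valued gamma. A hedged sketch
# of a fallback pattern one might use in user code under that assumption
# (import path assumed as above):
from torchrl.objectives.value.functional import (
    td_lambda_advantage_estimate,
    vec_td_lambda_advantage_estimate,
)

def safe_td_lambda(gamma, lmbda, v, v_next, r, done, terminated, rolling_gamma):
    try:
        return vec_td_lambda_advantage_estimate(
            gamma, lmbda, v, v_next, r,
            done=done, terminated=terminated, rolling_gamma=rolling_gamma,
        )
    except NotImplementedError:
        # fall back to the slower, non-vectorized implementation
        return td_lambda_advantage_estimate(
            gamma, lmbda, v, v_next, r,
            done=done, terminated=terminated, rolling_gamma=rolling_gamma,
        )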
done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v1b = td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -8475,8 +9847,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -8490,8 +9863,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -8500,8 +9874,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v2a = vec_td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -8509,8 +9884,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v2b = vec_td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -8518,8 +9894,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8531,22 +9908,17 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("N", [(3,), (3, 7)]) @pytest.mark.parametrize("T", [3, 5, 200]) - def test_successive_traj_tdadv( - self, - device, - N, - T, - ): + def test_successive_traj_tdadv(self, device, N, T): """Tests td_lambda_advantage_estimate against vec_td_lambda_advantage_estimate with gamma being a random tensor """ torch.manual_seed(0) - lmbda = torch.rand([]).item() - + # for td0, a done that is not terminated has no effect done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) done[..., T // 2 - 1, :] = 1 + terminated = done.clone() reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -8560,21 +9932,24 @@ def test_successive_traj_tdadv( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, ) v1a = td0_advantage_estimate( gamma_tensor[..., : T // 2, :], state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], ) v1b = td0_advantage_estimate( gamma_tensor[..., T // 2 :, :], state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -8595,8 +9970,10 @@ def test_successive_traj_gae( 
lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -8611,7 +9988,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v1a = generalized_advantage_estimate( gamma_tensor, @@ -8619,7 +9997,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v1b = generalized_advantage_estimate( gamma_tensor, @@ -8627,7 +10006,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -8637,7 +10017,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v2a = vec_generalized_advantage_estimate( gamma_tensor, @@ -8645,7 +10026,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v2b = vec_generalized_advantage_estimate( gamma_tensor, @@ -8653,7 +10035,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) torch.testing.assert_close(v2, torch.cat([v2a, v2b], -2), rtol=1e-4, atol=1e-4) @@ -8731,9 +10114,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: loss = MyLoss(actor_network, value_network) # modify params for p in loss.parameters(): - p.data += torch.randn_like(p) + if p.requires_grad: + p.data += torch.randn_like(p) - assert len(list(loss.parameters())) == 6 + assert len([p for p in loss.parameters() if p.requires_grad]) == 6 assert ( len(loss.actor_network_params.keys(include_nested=True, leaves_only=True)) == 4 ) @@ -9282,6 +10666,114 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) assert (value != value_).all() +def test_instantiate_with_different_keys(): + loss_1 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_1.set_keys(reward="a") + assert loss_1.tensor_keys.reward == "a" + loss_2 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_2.set_keys(reward="b") + assert loss_1.tensor_keys.reward == "a" + + +class TestBuffer: + # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half)) + # def test_param_cast(self, dtype): + # param = nn.Parameter(torch.zeros(3)) + # idb = param.data_ptr() + # param = param.to(dtype) + # assert param.data_ptr() == idb + # assert param.dtype == dtype + # assert param.data.dtype == dtype + # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half)) + # def test_buffer_cast(self, dtype): + # buffer = Buffer(torch.zeros(3)) + # idb = 
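# The loss-module hunk above perturbs and counts only p.requires_grad
# parameters: the target-network copies exposed by a LossModule are
# non-trainable and should not be modified directly, and casting the module
# moves trainable parameters and target buffers alike. A small hedged sketch
# reusing the DQNLoss construction that appears elsewhere in this file
# (values illustrative; target behaviour depends on the loss defaults):
import torch
from torch import nn
from torchrl.objectives import DQNLoss

loss = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot")
for p in loss.parameters():
    if p.requires_grad:                 # skip the frozen target copies
        p.data += torch.randn_like(p)
loss = loss.to(torch.double)            # casts params and buffers together
assert all(p.dtype == torch.double for p in loss.parameters())
assert all(
    b.dtype == torch.double for b in loss.buffers() if b.is_floating_point()
)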
buffer.data_ptr() + # buffer = buffer.to(dtype) + # assert isinstance(buffer, Buffer) + # assert buffer.data_ptr() == idb + # assert buffer.dtype == dtype + # assert buffer.data.dtype == dtype + + @pytest.mark.parametrize("create_target_params", [True, False]) + @pytest.mark.parametrize( + "dest", [torch.float, torch.double, torch.half, *get_default_devices()] + ) + def test_module_cast(self, create_target_params, dest): + # test that when casting a loss module, all the tensors (params and buffers) + # are properly cast + class DummyModule(LossModule): + def __init__(self): + common = nn.Linear(3, 4) + actor = nn.Linear(4, 4) + value = nn.Linear(4, 1) + common = TensorDictModule(common, in_keys=["obs"], out_keys=["hidden"]) + actor = TensorDictSequential( + common, + TensorDictModule(actor, in_keys=["hidden"], out_keys=["action"]), + ) + value = TensorDictSequential( + common, + TensorDictModule(value, in_keys=["hidden"], out_keys=["value"]), + ) + super().__init__() + self.convert_to_functional( + actor, + "actor", + expand_dim=None, + create_target_params=False, + compare_against=None, + ) + self.convert_to_functional( + value, + "value", + expand_dim=2, + create_target_params=create_target_params, + compare_against=actor.parameters(), + ) + + mod = DummyModule() + v_p1 = set(mod.value_params.values(True, True)).union( + set(mod.actor_params.values(True, True)) + ) + v_params1 = set(mod.parameters()) + v_buffers1 = set(mod.buffers()) + mod.to(dest) + v_p2 = set(mod.value_params.values(True, True)).union( + set(mod.actor_params.values(True, True)) + ) + v_params2 = set(mod.parameters()) + v_buffers2 = set(mod.buffers()) + assert v_p1 == v_p2 + assert v_params1 == v_params2 + assert v_buffers1 == v_buffers2 + for p in mod.parameters(): + assert isinstance(p, nn.Parameter) + for p in mod.buffers(): + assert isinstance(p, Buffer) + for p in mod.actor_params.values(True, True): + assert isinstance(p, (nn.Parameter, Buffer)) + for p in mod.value_params.values(True, True): + assert isinstance(p, (nn.Parameter, Buffer)) + if isinstance(dest, torch.dtype): + for p in mod.parameters(): + assert p.dtype == dest + for p in mod.buffers(): + assert p.dtype == dest + for p in mod.actor_params.values(True, True): + assert p.dtype == dest + for p in mod.value_params.values(True, True): + assert p.dtype == dest + else: + for p in mod.parameters(): + assert p.device == dest + for p in mod.buffers(): + assert p.device == dest + for p in mod.actor_params.values(True, True): + assert p.device == dest + for p in mod.value_params.values(True, True): + assert p.device == dest + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_distributed.py b/test/test_distributed.py index 9b18c709436..8dcbe33f79d 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -86,8 +86,9 @@ def _start_worker(cls): @classmethod def _test_distributed_collector_basic(cls, queue, frames_per_batch): cls._start_worker() - env = ContinuousActionVecMockEnv() - policy = RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + print("creating collector") collector = cls.distributed_class()( [env] * 2, policy, @@ -126,8 +127,8 @@ def test_distributed_collector_basic(self, frames_per_batch): @classmethod def _test_distributed_collector_mult(cls, queue, frames_per_batch): cls._start_worker() - env = ContinuousActionVecMockEnv() - policy = 
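# The test_distributed.py hunks above now hand the collectors the env *class*
# (a factory) rather than a live instance, so each remote worker constructs
# its own environment; only the policy spec is read from a throwaway local
# instance via env().action_spec. Generic sketch of the same factory pattern
# with SyncDataCollector (the distributed classes are parametrized in these
# tests); the RandomPolicy import path and the "Pendulum-v1" id are
# assumptions for illustration.
from torchrl.collectors import SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.envs.libs.gym import GymEnv

def make_env():
    return GymEnv("Pendulum-v1")        # a factory, not an instance

policy = RandomPolicy(make_env().action_spec)
collector = SyncDataCollector(
    make_env,                           # the callable is built on the worker
    policy,
    frames_per_batch=50,
    total_frames=200,
)
for batch in collector:                 # each batch is a TensorDict of frames
    pass
collector.shutdown()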
RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) collector = cls.distributed_class()( [env] * 2, policy, @@ -164,8 +165,8 @@ def test_distributed_collector_mult(self, frames_per_batch=200): @classmethod def _test_distributed_collector_sync(cls, queue, sync): frames_per_batch = 50 - env = ContinuousActionVecMockEnv() - policy = RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) collector = cls.distributed_class()( [env] * 2, policy, @@ -203,8 +204,8 @@ def test_distributed_collector_sync(self, sync): @classmethod def _test_distributed_collector_class(cls, queue, collector_class): frames_per_batch = 50 - env = ContinuousActionVecMockEnv() - policy = RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) collector = cls.distributed_class()( [env] * 2, policy, @@ -250,7 +251,7 @@ def test_distributed_collector_class(self, collector_class): def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): frames_per_batch = 50 total_frames = 300 - env = CountingEnv() + env = CountingEnv policy = CountingPolicy() if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been @@ -293,13 +294,7 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): MultiaSyncDataCollector, ], ) - @pytest.mark.parametrize( - "sync", - [ - False, - True, - ], - ) + @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_updatepolicy(self, collector_class, sync): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) @@ -369,7 +364,7 @@ def _test_distributed_collector_updatepolicy( ): frames_per_batch = 50 total_frames = 300 - env = CountingEnv() + env = CountingEnv policy = CountingPolicy() collector = cls.distributed_class()( [env] * 2, @@ -461,8 +456,8 @@ def _start_worker(cls): @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync, frames_per_batch=200): frames_per_batch = 50 - env = ContinuousActionVecMockEnv() - policy = RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) collector = self.distributed_class()( [env] * 2, policy, @@ -488,8 +483,8 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200): ) def test_distributed_collector_class(self, collector_class): frames_per_batch = 50 - env = ContinuousActionVecMockEnv() - policy = RandomPolicy(env.action_spec) + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) collector = self.distributed_class()( [env] * 2, policy, @@ -513,17 +508,11 @@ def test_distributed_collector_class(self, collector_class): MultiaSyncDataCollector, ], ) - @pytest.mark.parametrize( - "sync", - [ - False, - True, - ], - ) + @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_updatepolicy(self, collector_class, sync): frames_per_batch = 50 total_frames = 300 - env = CountingEnv() + env = CountingEnv policy = CountingPolicy() if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been diff --git a/test/test_distributions.py b/test/test_distributions.py index 82b08e81546..30bb0288dd4 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -15,49 +15,61 @@ from torchrl.modules import ( NormalParamWrapper, OneHotCategorical, + 
ReparamGradientStrategy, TanhNormal, TruncatedNormal, ) -from torchrl.modules.distributions import Delta, MaskedCategorical, TanhDelta +from torchrl.modules.distributions import ( + Delta, + MaskedCategorical, + MaskedOneHotCategorical, + TanhDelta, +) from torchrl.modules.distributions.continuous import SafeTanhTransform @pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") @pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("div_up", [1, 2]) -@pytest.mark.parametrize("div_down", [1, 2]) -def test_delta(device, div_up, div_down): - x = torch.randn(1000000, 4, device=device, dtype=torch.double) - d = Delta(x) - assert d.log_prob(d.mode).shape == x.shape[:-1] - assert (d.log_prob(d.mode) == float("inf")).all() - - x = torch.randn(1000000, 4, device=device, dtype=torch.double) - d = TanhDelta(x, -1 / div_down, 1.0 / div_up, atol=1e-4, rtol=1e-4) - xinv = d.transforms[0].inv(d.mode) - assert d.base_dist._is_equal(xinv).all() - assert d.log_prob(d.mode).shape == x.shape[:-1] - assert (d.log_prob(d.mode) == float("inf")).all() - - x = torch.randn(1000000, 4, device=device, dtype=torch.double) - d = TanhDelta( - x, - -torch.ones_like(x) / div_down, - torch.ones_like(x) / div_up, - atol=1e-4, - rtol=1e-4, - ) - xinv = d.transforms[0].inv(d.mode) - assert d.base_dist._is_equal(xinv).all() - assert d.log_prob(d.mode).shape == x.shape[:-1] - assert (d.log_prob(d.mode) == float("inf")).all() +class TestDelta: + def test_delta_logprob(self, device): + x = torch.randn(1000000, 4, device=device, dtype=torch.double) + d = Delta(x) + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() + + @pytest.mark.parametrize("div_up", [1, 2]) + @pytest.mark.parametrize("div_down", [1, 2]) + def test_tanhdelta_logprob(self, device, div_up, div_down): + x = torch.randn(1000000, 4, device=device, dtype=torch.double) + d = TanhDelta(x, -1 / div_down, 1.0 / div_up, atol=1e-4, rtol=1e-4) + xinv = d.transforms[0].inv(d.mode) + assert d.base_dist._is_equal(xinv).all() + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() + + @pytest.mark.parametrize("div_up", [1, 2]) + @pytest.mark.parametrize("div_down", [1, 2]) + def test_tanhdelta_inv(self, device, div_up, div_down): + x = torch.randn(1000000, 4, device=device, dtype=torch.double) + d = TanhDelta( + x, + -torch.ones_like(x) / div_down, + torch.ones_like(x) / div_up, + atol=1e-4, + rtol=1e-4, + ) + xinv = d.transforms[0].inv(d.mode) + assert d.base_dist._is_equal(xinv).all() + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() - x = torch.randn(1000000, 4, device=device) - d = TanhDelta(x, -torch.ones_like(x), torch.ones_like(x), atol=1e-4, rtol=1e-4) - xinv = d.transforms[0].inv(d.mode) - assert d.base_dist._is_equal(xinv).all() - assert d.log_prob(d.mode).shape == x.shape[:-1] - assert (d.log_prob(d.mode) == float("inf")).all() + def test_tanhdelta_inv_ones(self, device): + x = torch.randn(1000000, 4, device=device) + d = TanhDelta(x, -torch.ones_like(x), torch.ones_like(x), atol=1e-4, rtol=1e-4) + xinv = d.transforms[0].inv(d.mode) + assert d.base_dist._is_equal(xinv).all() + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() def _map_all(*tensors_or_other, device): @@ -68,42 +80,43 @@ def _map_all(*tensors_or_other, device): yield t -@pytest.mark.parametrize( - "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, 
-0.5]), -0.1] -) -@pytest.mark.parametrize( - "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1] -) -@pytest.mark.parametrize( - "vecs", - [ - (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])), - (torch.zeros(7, 3), torch.ones(7, 3)), - ], -) -@pytest.mark.parametrize( - "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3] -) -@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_tanhnormal(min, max, vecs, upscale, shape, device): - min, max, vecs, upscale, shape = _map_all( - min, max, vecs, upscale, shape, device=device +class TestTanhNormal: + @pytest.mark.parametrize( + "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1] ) - torch.manual_seed(0) - d = TanhNormal( - *vecs, - upscale=upscale, - min=min, - max=max, + @pytest.mark.parametrize( + "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1] ) - for _ in range(100): - a = d.rsample(shape) - assert a.shape[: len(shape)] == shape - assert (a >= d.min).all() - assert (a <= d.max).all() - lp = d.log_prob(a) - assert torch.isfinite(lp).all() + @pytest.mark.parametrize( + "vecs", + [ + (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])), + (torch.zeros(7, 3), torch.ones(7, 3)), + ], + ) + @pytest.mark.parametrize( + "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3] + ) + @pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_tanhnormal(self, min, max, vecs, upscale, shape, device): + min, max, vecs, upscale, shape = _map_all( + min, max, vecs, upscale, shape, device=device + ) + torch.manual_seed(0) + d = TanhNormal( + *vecs, + upscale=upscale, + min=min, + max=max, + ) + for _ in range(100): + a = d.rsample(shape) + assert a.shape[: len(shape)] == shape + assert (a >= d.min).all() + assert (a <= d.max).all() + lp = d.log_prob(a) + assert torch.isfinite(lp).all() @pytest.mark.parametrize( @@ -124,24 +137,40 @@ def test_tanhnormal(min, max, vecs, upscale, shape, device): ) @pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) @pytest.mark.parametrize("device", get_default_devices()) -def test_truncnormal(min, max, vecs, upscale, shape, device): - torch.manual_seed(0) - min, max, vecs, upscale, shape = _map_all( - min, max, vecs, upscale, shape, device=device - ) - d = TruncatedNormal( - *vecs, - upscale=upscale, - min=min, - max=max, - ) - for _ in range(100): - a = d.rsample(shape) - assert a.shape[: len(shape)] == shape - assert (a >= d.min).all() - assert (a <= d.max).all() - lp = d.log_prob(a) - assert torch.isfinite(lp).all() +class TestTruncatedNormal: + def test_truncnormal(self, min, max, vecs, upscale, shape, device): + torch.manual_seed(0) + min, max, vecs, upscale, shape = _map_all( + min, max, vecs, upscale, shape, device=device + ) + d = TruncatedNormal( + *vecs, + upscale=upscale, + min=min, + max=max, + ) + for _ in range(100): + a = d.rsample(shape) + assert a.shape[: len(shape)] == shape + assert (a >= d.min).all() + assert (a <= d.max).all() + lp = d.log_prob(a) + assert torch.isfinite(lp).all() + + def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device): + torch.manual_seed(0) + min, max, vecs, upscale, shape = _map_all( + min, max, vecs, upscale, shape, device=device + ) + d = TruncatedNormal( + *vecs, + upscale=upscale, + min=min, + max=max, + ) + assert d.mode is not None + assert d.entropy() is not None + 
assert d.mean is not None @pytest.mark.parametrize( @@ -346,6 +375,203 @@ def test_sample_sparse(self, neg_inf: float) -> None: torch.testing.assert_close(sample_probs, ref_probs, rtol=1e-5, atol=1e-2) +class TestOneHotCategorical: + def test_one_hot(self): + torch.manual_seed(0) + logits = torch.randn(1, 10) + torch.manual_seed(0) + d = OneHotCategorical(logits=logits) + s_a = d.sample((1,)) + torch.manual_seed(0) + d = OneHotCategorical(probs=torch.softmax(logits, -1)) + s_b = d.sample((1,)) + torch.testing.assert_close(s_a, s_b) + assert s_a.dtype == torch.long + assert s_b.dtype == torch.long + assert s_a.sum(-1) == 1 + assert s_b.sum(-1) == 1 + assert s_a.shape[-1] == 10 + assert s_b.shape[-1] == 10 + + @pytest.mark.parametrize( + "reparam", + (ReparamGradientStrategy.PassThrough, ReparamGradientStrategy.RelaxedOneHot), + ) + def test_reparam(self, reparam): + torch.manual_seed(0) + logits = torch.randn(1, 10, requires_grad=True) + torch.manual_seed(0) + d = OneHotCategorical(logits=logits, grad_method=reparam) + s_a = d.rsample((1,)) + torch.manual_seed(0) + d = OneHotCategorical(probs=torch.softmax(logits, -1), grad_method=reparam) + s_b = d.rsample((1,)) + s_a[s_a.detach().bool()].sum().backward() + assert logits.grad is not None and logits.grad.norm() > 0 + logits.grad = None + s_b[s_b.detach().bool()].sum().backward() + assert logits.grad is not None and logits.grad.norm() > 0 + + +class TestMaskedOneHotCategorical: + def test_errs(self): + with pytest.raises( + ValueError, + match="Either `probs` or `logits` must be specified, but not both", + ): + MaskedOneHotCategorical( + logits=torch.tensor(()), probs=torch.tensor(()), mask=torch.tensor(()) + ) + with pytest.raises(ValueError, match="must be provided"): + MaskedOneHotCategorical(probs=torch.tensor(()), mask=None) + with pytest.raises(ValueError, match="must be provided"): + MaskedOneHotCategorical( + probs=torch.tensor(()), mask=torch.tensor(()), indices=torch.tensor(()) + ) + + @pytest.mark.parametrize("neg_inf", [-float(10.0), -float("inf")]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("sparse", [True, False]) + @pytest.mark.parametrize("logits", [True, False]) + def test_construction(self, neg_inf, sparse, logits, device): + torch.manual_seed(0) + logits_vals = torch.randn(4, device=device) / 100 # almost equal probabilities + if logits: + logits = logits_vals + probs = None + else: + probs = logits_vals.softmax(-1) + logits = None + + if sparse: + indices = torch.tensor([0, 2, 3], device=device) + mask = None + else: + mask = torch.tensor([True, False, True, True], device=device) + indices = None + dist = MaskedOneHotCategorical( + logits=logits, probs=probs, indices=indices, mask=mask, neg_inf=neg_inf + ) + dist_categ = MaskedCategorical( + logits=logits, probs=probs, indices=indices, mask=mask, neg_inf=neg_inf + ) + for _ in range(10): + sample = dist.sample((100,)) + assert not sample[..., 1].any() + assert torch.isfinite(dist.log_prob(sample)).all() + torch.testing.assert_close( + dist.log_prob(sample), dist_categ.log_prob(sample.argmax(-1)) + ) + assert sample.device == device + + sample_unfeasible = torch.zeros_like(sample) + sample_unfeasible[..., 1] = 1 + if neg_inf == -float("inf"): + assert (dist.log_prob(sample_unfeasible) == neg_inf).all() + else: + assert (dist.log_prob(sample_unfeasible) > -float("inf")).all() + + @pytest.mark.parametrize("neg_inf", [-float(10.0), -float("inf")]) + @pytest.mark.parametrize("sparse", [True, False]) + 
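# MaskedOneHotCategorical, exercised by the new tests above, restricts a
# one-hot categorical to a feasible subset given either a boolean `mask` or
# sparse `indices`; `neg_inf` is the log-probability assigned to masked-out
# categories. Minimal sketch with illustrative values:
import torch
from torchrl.modules.distributions import MaskedOneHotCategorical

logits = torch.randn(4)
mask = torch.tensor([True, False, True, True])
dist = MaskedOneHotCategorical(logits=logits, mask=mask, neg_inf=-float("inf"))
sample = dist.sample((8,))              # one-hot samples, never category 1
assert not sample[..., 1].any()
assert torch.isfinite(dist.log_prob(sample)).all()

# equivalently, pass the feasible indices instead of a mask
dist_sparse = MaskedOneHotCategorical(
    logits=logits, indices=torch.tensor([0, 2, 3]), neg_inf=-float("inf")
)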
@pytest.mark.parametrize("logits", [True, False]) + def test_backprop(self, neg_inf, sparse, logits): + torch.manual_seed(0) + logits_vals = ( + torch.randn(4).div_(100).requires_grad_() + ) # almost equal probabilities + if logits: + logits = logits_vals + probs = None + else: + probs = logits_vals.softmax(-1) + logits = None + + if sparse: + indices = torch.tensor([0, 2, 3]) + mask = None + else: + mask = torch.tensor([True, False, True, True]) + indices = None + dist = MaskedOneHotCategorical( + logits=logits, probs=probs, indices=indices, mask=mask, neg_inf=neg_inf + ) + sample = dist.sample((100,)) + lp = dist.log_prob(sample) + lp.sum().backward() + assert logits_vals.grad is not None + + @pytest.mark.parametrize("neg_inf", [-1e20, float("-inf")]) + def test_sample(self, neg_inf: float) -> None: + torch.manual_seed(0) + logits = torch.randn(4) + probs = F.softmax(logits, dim=-1) + mask = torch.tensor([True, False, True, True]) + ref_probs = probs.masked_fill(~mask, 0.0) + ref_probs /= ref_probs.sum(dim=-1, keepdim=True) + + dist = MaskedOneHotCategorical( + probs=probs, + mask=mask, + neg_inf=neg_inf, + ) + num_samples = 10000 + samples = dist.sample([num_samples]).argmax(-1) + sample_probs = torch.bincount(samples) / num_samples + torch.testing.assert_close(sample_probs, ref_probs, rtol=1e-5, atol=1e-2) + + @pytest.mark.parametrize("neg_inf", [-1e20, float("-inf")]) + def test_sample_sparse(self, neg_inf: float) -> None: + torch.manual_seed(0) + logits = torch.randn(4) + probs = F.softmax(logits, dim=-1) + mask = torch.tensor([True, False, True, True]) + indices = torch.tensor([0, 2, 3]) + ref_probs = probs.masked_fill(~mask, 0.0) + ref_probs /= ref_probs.sum(dim=-1, keepdim=True) + + dist = MaskedOneHotCategorical( + logits=logits, + indices=indices, + neg_inf=neg_inf, + ) + num_samples = 10000 + samples = dist.sample([num_samples]).argmax(-1) + sample_probs = torch.bincount(samples) / num_samples + torch.testing.assert_close(sample_probs, ref_probs, rtol=1e-5, atol=1e-2) + + @pytest.mark.parametrize( + "grad_method", + [ReparamGradientStrategy.RelaxedOneHot, ReparamGradientStrategy.PassThrough], + ) + @pytest.mark.parametrize("sparse", [True, False]) + def test_reparam(self, grad_method, sparse): + torch.manual_seed(0) + neg_inf = -float("inf") + logits = torch.randn(100, requires_grad=True) + probs = F.softmax(logits, dim=-1) + # mask = torch.tensor([True, False, True, True]) + # indices = torch.tensor([0, 2, 3]) + if sparse: + indices = torch.randint(100, (70,)).unique().view(-1) + mask = None + else: + mask = torch.zeros(100, dtype=torch.bool).bernoulli_() + indices = None + + dist = MaskedOneHotCategorical( + logits=logits, + indices=indices, + neg_inf=neg_inf, + grad_method=grad_method, + mask=mask, + ) + + s = dist.rsample() + assert s.shape[-1] == 100 + s[s.detach().bool()].sum().backward() + assert logits.grad is not None and logits.grad.norm() > 0 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_env.py b/test/test_env.py index a52f140fcd7..70fe4bec37a 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -5,8 +5,10 @@ import argparse import os.path +import re from collections import defaultdict from functools import partial +from sys import platform import numpy as np import pytest @@ -16,10 +18,13 @@ from _utils_internal import ( _make_envs, CARTPOLE_VERSIONED, + check_rollout_consistency_multikey_env, + decorate_thread_sub_func, 
get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED, + rand_reset, ) from mocking_classes import ( ActionObsMergeLinear, @@ -33,19 +38,26 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + HeteroCountingEnv, + HeteroCountingEnvPolicy, MockBatchedLockedEnv, MockBatchedUnLockedEnv, MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, NestedCountingEnv, ) from packaging import version +from tensordict import dense_stack_tds from tensordict.nn import TensorDictModuleBase -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict.tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict +from tensordict.utils import _unravel_key_to_tuple from torch import nn from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data.tensor_specs import ( CompositeSpec, + DiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) @@ -54,7 +66,14 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv -from torchrl.envs.utils import check_env_specs, make_composite_from_td, step_mdp +from torchrl.envs.utils import ( + check_env_specs, + check_marl_grouping, + make_composite_from_td, + MarlGroupMapType, + step_mdp, + terminated_or_truncated, +) from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator from torchrl.modules.tensordict_module import WorldModelWrapper @@ -76,6 +95,7 @@ _atari_found = False atari_confs = defaultdict(lambda: "") +IS_OSX = platform == "darwin" ## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between ## the serial and parallel batched envs @@ -136,7 +156,7 @@ def test_env_seed(env_name, frame_skip, seed=0): env.set_seed(seed) td0b = env.fake_tensordict() td0b = env.reset(tensordict=td0b) - td1b = env.step(td0b.clone().set("action", action)) + td1b = env.step(td0b.exclude("next").clone().set("action", action)) assert_allclose_td(td0a, td0b.select(*td0a.keys())) assert_allclose_td(td1a, td1b) @@ -213,20 +233,10 @@ def test_rollout_predictability(device): @pytest.mark.skipif(not _has_gym, reason="no gym") -@pytest.mark.parametrize( - "env_name", - [ - PENDULUM_VERSIONED, - ], -) -@pytest.mark.parametrize( - "frame_skip", - [ - 1, - ], -) +@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED]) +@pytest.mark.parametrize("frame_skip", [1]) @pytest.mark.parametrize("truncated_key", ["truncated", "done"]) -@pytest.mark.parametrize("parallel", [True, False]) +@pytest.mark.parametrize("parallel", [False, True]) def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0): envs = [] for horizon in [20, 30, 40]: @@ -244,6 +254,12 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0): out = env.rollout(100, break_when_any_done=False) assert out.names[-1] == "time" assert out.shape == torch.Size([3, 100]) + assert ( + out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19]) + ).all() + assert ( + out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20]) + ).all() assert ( out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2]) ).all() @@ -276,15 +292,16 @@ def test_mb_rollout(self, device, seed=0): check_env_specs(mb_env) rollout = mb_env.rollout(max_steps=100) expected_keys = { - ("next", key) for key in 
(*mb_env.observation_spec.keys(), "reward", "done") + ("next", key) + for key in (*mb_env.observation_spec.keys(), "reward", "done", "terminated") } expected_keys = expected_keys.union( - set(mb_env.input_spec["_action_spec"].keys()) + set(mb_env.input_spec["full_action_spec"].keys()) ) expected_keys = expected_keys.union( - set(mb_env.input_spec["_state_spec"].keys()) + set(mb_env.input_spec["full_state_spec"].keys()) ) - expected_keys = expected_keys.union({"done", "next"}) + expected_keys = expected_keys.union({"done", "terminated", "next"}) assert set(rollout.keys(True)) == expected_keys assert rollout[("next", "hidden_observation")].shape == (10, 100, 4) @@ -316,7 +333,9 @@ def test_mb_env_batch_lock(self, device, seed=0): td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict() mb_env.step(td) - with pytest.raises(RuntimeError, match="Expected a tensordict with shape"): + with pytest.raises( + RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape") + ): mb_env.step(td_expanded) mb_env = DummyModelBasedEnvBase( @@ -458,12 +477,7 @@ def test_parallel_env( transformed_out=transformed_out, N=N, ) - td = TensorDict( - source={"action": env0.action_spec.rand((N,))}, - batch_size=[ - N, - ], - ) + td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N]) td1 = env_parallel.step(td) assert not td1.is_shared() assert ("next", "done") in td1.keys(True) @@ -477,12 +491,7 @@ def test_parallel_env( ) _ = env_parallel.step(td) - td_reset = TensorDict( - source={"_reset": env_parallel.done_spec.rand()}, - batch_size=[ - N, - ], - ) + td_reset = TensorDict(source=rand_reset(env_parallel), batch_size=[N]) env_parallel.reset(tensordict=td_reset) td = env_parallel.rollout(policy=None, max_steps=T) @@ -534,12 +543,7 @@ def test_parallel_env_with_policy( ), ) - td = TensorDict( - source={"action": env0.action_spec.rand((N,))}, - batch_size=[ - N, - ], - ) + td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N]) td1 = env_parallel.step(td) assert not td1.is_shared() assert ("next", "done") in td1.keys(True) @@ -553,12 +557,7 @@ def test_parallel_env_with_policy( ) _ = env_parallel.step(td) - td_reset = TensorDict( - source={"_reset": env_parallel.done_spec.rand()}, - batch_size=[ - N, - ], - ) + td_reset = TensorDict(source=rand_reset(env_parallel), batch_size=[N]) env_parallel.reset(tensordict=td_reset) td = env_parallel.rollout(policy=policy, max_steps=T) @@ -896,15 +895,7 @@ def env_fn2(seed): env1.close() env2.close() - @pytest.mark.parametrize( - "batch_size", - [ - (32, 5), - (4,), - (1,), - (), - ], - ) + @pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()]) @pytest.mark.parametrize("n_workers", [2, 1]) def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): torch.manual_seed(1) @@ -929,19 +920,17 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): assert (td["next", "done"] == 1).all() assert (td["next"]["observation"] == max_steps + 1).all() - _reset = env.done_spec.rand() - while not _reset.any(): - _reset = env.done_spec.rand() - - td_reset = env.reset( - TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) + td_reset = TensorDict( + rand_reset(env), batch_size=env.batch_size, device=env.device ) + reset = td_reset["_reset"] + td_reset = env.reset(td_reset) env.close() - assert (td_reset["done"][_reset] == 0).all() - assert (td_reset["observation"][_reset] == 0).all() - assert (td_reset["done"][~_reset] == 1).all() - assert 
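# rand_reset(), used above, is a helper from the suite's _utils_internal that
# draws a random boolean "_reset" entry; the underlying mechanism is
# unchanged: a boolean "_reset" mask in the tensordict passed to env.reset()
# restarts only the selected batch entries. Hedged sketch using the suite's
# CountingEnv mock (importable when running from the test directory):
import torch
from tensordict import TensorDict
from mocking_classes import CountingEnv

env = CountingEnv(max_steps=3, batch_size=(4,))
env.reset()
mask = torch.tensor([True, False, True, False]).unsqueeze(-1)
td_reset = env.reset(TensorDict({"_reset": mask}, batch_size=env.batch_size))
# entries where mask is True are restarted; the others keep their state
assert (td_reset["observation"][mask] == 0).all()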
(td_reset["observation"][~_reset] == max_steps + 1).all() + assert (td_reset["done"][reset] == 0).all() + assert (td_reset["observation"][reset] == 0).all() + assert (td_reset["done"][~reset] == 1).all() + assert (td_reset["observation"][~reset] == max_steps + 1).all() @pytest.mark.parametrize("nested_obs_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) @@ -1006,6 +995,7 @@ def test_parallel_env_nested( @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) def test_env_base_reset_flag(batch_size, max_steps=3): + torch.manual_seed(0) env = CountingEnv(max_steps=max_steps, batch_size=batch_size) env.set_seed(1) @@ -1025,15 +1015,14 @@ def test_env_base_reset_flag(batch_size, max_steps=3): assert (td["next", "done"] == 1).all() assert (td["next", "observation"] == max_steps + 1).all() - _reset = env.done_spec.rand() - td_reset = env.reset( - TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) - ) + td_reset = TensorDict(rand_reset(env), batch_size=env.batch_size, device=env.device) + reset = td_reset["_reset"] + td_reset = env.reset(td_reset) - assert (td_reset["done"][_reset] == 0).all() - assert (td_reset["observation"][_reset] == 0).all() - assert (td_reset["done"][~_reset] == 1).all() - assert (td_reset["observation"][~_reset] == max_steps + 1).all() + assert (td_reset["done"][reset] == 0).all() + assert (td_reset["observation"][reset] == 0).all() + assert (td_reset["done"][~reset] == 1).all() + assert (td_reset["observation"][~reset] == max_steps + 1).all() @pytest.mark.skipif(not _has_gym, reason="no gym") @@ -1229,9 +1218,9 @@ def test_nested( exclude_reward=exclude_reward, exclude_done=exclude_done, exclude_action=exclude_action, - reward_key=reward_key, - done_key=done_key, - action_key=action_key, + reward_keys=reward_key, + done_keys=done_key, + action_keys=action_key, keep_other=keep_other, ) td_nested_keys = td.keys(True, True) @@ -1330,9 +1319,9 @@ def test_nested_partially( exclude_reward=exclude_reward, exclude_done=exclude_done, exclude_action=exclude_action, - reward_key=reward_key, - done_key=done_key, - action_key=action_key, + reward_keys=reward_key, + done_keys=done_key, + action_keys=action_key, keep_other=keep_other, ) td_keys_nested = td.keys(True, True) @@ -1369,9 +1358,9 @@ def test_nested_partially( exclude_reward=exclude_reward, exclude_done=exclude_done, exclude_action=exclude_action, - reward_key=reward_key, - done_key=done_key, - action_key=action_key, + reward_keys=reward_key, + done_keys=done_key, + action_keys=action_key, keep_other=keep_other, ) td_keys = td.keys() @@ -1383,6 +1372,146 @@ def test_nested_partially( assert nested_key[0] not in td_keys assert (td[other_key] == 0).all() + @pytest.mark.parametrize("het_action", [True, False]) + @pytest.mark.parametrize("het_done", [True, False]) + @pytest.mark.parametrize("het_reward", [True, False]) + @pytest.mark.parametrize("het_other", [True, False]) + @pytest.mark.parametrize("het_obs", [True, False]) + @pytest.mark.parametrize("exclude_reward", [True, False]) + @pytest.mark.parametrize("exclude_done", [True, False]) + @pytest.mark.parametrize("exclude_action", [True, False]) + @pytest.mark.parametrize("keep_other", [True, False]) + def test_heterogeenous( + self, + het_action, + het_done, + het_reward, + het_other, + het_obs, + exclude_reward, + exclude_done, + exclude_action, + keep_other, + ): + td_batch_size = 4 + nested_dim = 3 + nested_batch_size = (td_batch_size, nested_dim) + nested_key = ("data",) + + reward_key = "reward" + 
nested_reward_key = nested_key + (reward_key,) + done_key = "done" + nested_done_key = nested_key + (done_key,) + action_key = "action" + nested_action_key = nested_key + (action_key,) + obs_key = "state" + nested_obs_key = nested_key + (obs_key,) + other_key = "beatles" + nested_other_key = nested_key + (other_key,) + + tds = [] + for i in range(1, nested_dim + 1): + tds.append( + TensorDict( + { + nested_key: TensorDict( + { + reward_key: torch.zeros( + td_batch_size, i if het_reward else 1 + ), + done_key: torch.zeros( + td_batch_size, i if het_done else 1 + ), + action_key: torch.zeros( + td_batch_size, i if het_action else 1 + ), + obs_key: torch.zeros( + td_batch_size, i if het_obs else 1 + ), + other_key: torch.zeros( + td_batch_size, i if het_other else 1 + ), + }, + [td_batch_size], + ), + "next": { + nested_key: TensorDict( + { + reward_key: torch.ones( + td_batch_size, i if het_reward else 1 + ), + done_key: torch.ones( + td_batch_size, i if het_done else 1 + ), + obs_key: torch.ones( + td_batch_size, i if het_obs else 1 + ), + }, + [td_batch_size], + ), + }, + }, + [td_batch_size], + ) + ) + lazy_td = torch.stack(tds, dim=1) + input_td = lazy_td + + td = step_mdp( + lazy_td.lock_(), + exclude_reward=exclude_reward, + exclude_done=exclude_done, + exclude_action=exclude_action, + reward_keys=nested_reward_key, + done_keys=nested_done_key, + action_keys=nested_action_key, + keep_other=keep_other, + ) + td_nested_keys = td.keys(True, True) + td_keys = td.keys() + for i in range(nested_dim): + if het_obs: + assert td[..., i][nested_obs_key].shape == (td_batch_size, i + 1) + else: + assert td[..., i][nested_obs_key].shape == (td_batch_size, 1) + assert (td[..., i][nested_obs_key] == 1).all() + if exclude_reward: + assert nested_reward_key not in td_keys + else: + for i in range(nested_dim): + if het_reward: + assert td[..., i][nested_reward_key].shape == (td_batch_size, i + 1) + else: + assert td[..., i][nested_reward_key].shape == (td_batch_size, 1) + assert (td[..., i][nested_reward_key] == 1).all() + if exclude_done: + assert nested_done_key not in td_keys + else: + for i in range(nested_dim): + if het_done: + assert td[..., i][nested_done_key].shape == (td_batch_size, i + 1) + else: + assert td[..., i][nested_done_key].shape == (td_batch_size, 1) + assert (td[..., i][nested_done_key] == 1).all() + if exclude_action: + assert nested_action_key not in td_keys + else: + for i in range(nested_dim): + if het_action: + assert td[..., i][nested_action_key].shape == (td_batch_size, i + 1) + else: + assert td[..., i][nested_action_key].shape == (td_batch_size, 1) + assert (td[..., i][nested_action_key] == 0).all() + if not keep_other: + assert nested_other_key not in td_keys + else: + for i in range(nested_dim): + if het_other: + assert td[..., i][nested_other_key].shape == (td_batch_size, i + 1) + else: + assert td[..., i][nested_other_key].shape == (td_batch_size, 1) + assert (td[..., i][nested_other_key] == 0).all() + @pytest.mark.parametrize("device", get_default_devices()) def test_batch_locked(device): @@ -1430,9 +1559,7 @@ def test_batch_unlocked_with_batch_size(device): td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() td = env.step(td) - with pytest.raises( - RuntimeError, match="Expected a tensordict with shape==env.shape, " - ): + with pytest.raises(RuntimeError, match="Expected a tensordict with shape"): env.step(td_expanded) @@ -1493,6 +1620,23 @@ def test_make_spec_from_td(): assert val.dtype is spec[key].dtype +@pytest.mark.parametrize("group_type", 
list(MarlGroupMapType)) +def test_marl_group_type(group_type): + agent_names = ["agent"] + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = ["agent", "agent"] + with pytest.raises(ValueError): + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = ["agent_0", "agent_1"] + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = [] + with pytest.raises(ValueError): + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device") class TestConcurrentEnvs: """Concurrent parallel envs on multiple procs can interfere.""" @@ -1669,7 +1813,6 @@ def test_mp_collector(self, nproc): class TestNestedSpecs: @pytest.mark.parametrize("envclass", ["CountingEnv", "NestedCountingEnv"]) def test_nested_env(self, envclass): - if envclass == "CountingEnv": env = CountingEnv() elif envclass == "NestedCountingEnv": @@ -1677,17 +1820,27 @@ def test_nested_env(self, envclass): else: raise NotImplementedError reset = env.reset() - assert not isinstance(env.done_spec, CompositeSpec) assert not isinstance(env.reward_spec, CompositeSpec) - assert env.done_spec == env.output_spec[("_done_spec", *env.done_key)] - assert env.reward_spec == env.output_spec[("_reward_spec", *env.reward_key)] + for done_key in env.done_keys: + assert ( + env.full_done_spec[done_key] + == env.output_spec[("full_done_spec", *_unravel_key_to_tuple(done_key))] + ) + assert ( + env.reward_spec + == env.output_spec[ + ("full_reward_spec", *_unravel_key_to_tuple(env.reward_key)) + ] + ) if envclass == "NestedCountingEnv": - assert env.done_key == ("data", "done") + for done_key in env.done_keys: + assert done_key in (("data", "done"), ("data", "terminated")) assert env.reward_key == ("data", "reward") assert ("data", "done") in reset.keys(True) assert ("data", "states") in reset.keys(True) assert ("data", "reward") not in reset.keys(True) - assert env.done_key in reset.keys(True) + for done_key in env.done_keys: + assert done_key in reset.keys(True) assert env.reward_key not in reset.keys(True) next_state = env.rand_step() @@ -1695,12 +1848,12 @@ def test_nested_env(self, envclass): assert ("next", "data", "done") in next_state.keys(True) assert ("next", "data", "states") in next_state.keys(True) assert ("next", "data", "reward") in next_state.keys(True) - assert ("next", *env.done_key) in next_state.keys(True) - assert ("next", *env.reward_key) in next_state.keys(True) + for done_key in env.done_keys: + assert ("next", *_unravel_key_to_tuple(done_key)) in next_state.keys(True) + assert ("next", *_unravel_key_to_tuple(env.reward_key)) in next_state.keys(True) @pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)]) def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): - env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) td_reset = env.reset() @@ -1750,6 +1903,130 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): ) +class TestHeteroEnvs: + @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) + def test_reset(self, batch_size): + env = HeteroCountingEnv(batch_size=batch_size) + env.reset() + + @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) + def test_rand_step(self, batch_size): + env = HeteroCountingEnv(batch_size=batch_size) + td = env.reset() + assert (td["lazy"][..., 0]["tensor_0"] == 0).all() + td = env.rand_step() + assert (td["next", "lazy"][..., 
0]["tensor_0"] == 1).all() + td = env.rand_step() + assert (td["next", "lazy"][..., 1]["tensor_1"] == 2).all() + + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) + def test_rollout(self, batch_size, rollout_steps, n_lazy_dim=3): + env = HeteroCountingEnv(batch_size=batch_size) + td = env.rollout(rollout_steps, return_contiguous=False) + td = dense_stack_tds(td) + + assert isinstance(td, TensorDict) + assert td.batch_size == (*batch_size, rollout_steps) + + assert isinstance(td["lazy"], LazyStackedTensorDict) + assert td["lazy"].shape == (*batch_size, rollout_steps, n_lazy_dim) + assert td["lazy"].stack_dim == len(td["lazy"].batch_size) - 1 + + assert (td[..., -1]["next", "state"] == rollout_steps).all() + assert (td[..., -1]["next", "lazy", "camera"] == rollout_steps).all() + assert ( + td["lazy"][(0,) * len(batch_size)][..., 0]["tensor_0"].squeeze(-1) + == torch.arange(rollout_steps) + ).all() + + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) + @pytest.mark.parametrize("count", [True, False]) + def test_rollout_policy(self, batch_size, rollout_steps, count): + env = HeteroCountingEnv(batch_size=batch_size) + policy = HeteroCountingEnvPolicy( + env.input_spec["full_action_spec"], count=count + ) + td = env.rollout(rollout_steps, policy=policy, return_contiguous=False) + td = dense_stack_tds(td) + for i in range(env.n_nested_dim): + if count: + agent_obs = td["lazy"][(0,) * len(batch_size)][..., i][f"tensor_{i}"] + for _ in range(i + 1): + agent_obs = agent_obs.mean(-1) + assert (agent_obs == torch.arange(rollout_steps)).all() + assert (td["lazy"][..., i]["action"] == 1).all() + else: + assert (td["lazy"][..., i]["action"] == 0).all() + + @pytest.mark.parametrize("batch_size", [(1, 2)]) + @pytest.mark.parametrize("env_type", ["serial", "parallel"]) + def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): + env_fun = lambda: HeteroCountingEnv(batch_size=batch_size) + if env_type == "serial": + vec_env = SerialEnv(n_workers, env_fun) + else: + vec_env = ParallelEnv(n_workers, env_fun) + vec_batch_size = (n_workers,) + batch_size + # check_env_specs(vec_env, return_contiguous=False) + policy = HeteroCountingEnvPolicy(vec_env.input_spec["full_action_spec"]) + vec_env.reset() + td = vec_env.rollout( + rollout_steps, + policy=policy, + return_contiguous=False, + break_when_any_done=False, + ) + td = dense_stack_tds(td) + for i in range(env_fun().n_nested_dim): + agent_obs = td["lazy"][(0,) * len(vec_batch_size)][..., i][f"tensor_{i}"] + for _ in range(i + 1): + agent_obs = agent_obs.mean(-1) + assert (agent_obs == torch.arange(rollout_steps)).all() + assert (td["lazy"][..., i]["action"] == 1).all() + + +@pytest.mark.parametrize("seed", [0]) +class TestMultiKeyEnvs: + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("rollout_steps", [1, 5]) + @pytest.mark.parametrize("max_steps", [2, 5]) + def test_rollout(self, batch_size, rollout_steps, max_steps, seed): + env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps) + policy = MultiKeyCountingEnvPolicy(full_action_spec=env.action_spec) + td = env.rollout(rollout_steps, policy=policy) + torch.manual_seed(seed) + check_rollout_consistency_multikey_env(td, max_steps=max_steps) + + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + @pytest.mark.parametrize("rollout_steps", [5]) + @pytest.mark.parametrize("env_type", ["serial", "parallel"]) + 
@pytest.mark.parametrize("max_steps", [2, 5]) + def test_parallel( + self, batch_size, rollout_steps, env_type, max_steps, seed, n_workers=2 + ): + torch.manual_seed(seed) + env_fun = lambda: MultiKeyCountingEnv( + batch_size=batch_size, max_steps=max_steps + ) + if env_type == "serial": + vec_env = SerialEnv(n_workers, env_fun) + else: + vec_env = ParallelEnv(n_workers, env_fun) + + # check_env_specs(vec_env) + policy = MultiKeyCountingEnvPolicy( + full_action_spec=vec_env.input_spec["full_action_spec"] + ) + vec_env.reset() + td = vec_env.rollout( + rollout_steps, + policy=policy, + ) + check_rollout_consistency_multikey_env(td, max_steps=max_steps) + + @pytest.mark.parametrize( "envclass", [ @@ -1768,6 +2045,8 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): MockBatchedUnLockedEnv, MockSerialEnv, NestedCountingEnv, + HeteroCountingEnv, + MultiKeyCountingEnv, ], ) def test_mocking_envs(envclass): @@ -1775,7 +2054,187 @@ def test_mocking_envs(envclass): env.set_seed(100) reset = env.reset() _ = env.rand_step(reset) - check_env_specs(env, seed=100) + check_env_specs(env, seed=100, return_contiguous=False) + + +class TestTerminatedOrTruncated: + def test_terminated_or_truncated_nospec(self): + data = TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2]) + assert not terminated_or_truncated(data, write_full_false=True) + assert data["_reset"].shape == (2, 1) + assert not terminated_or_truncated(data, write_full_false=False) + assert data.get("_reset", None) is None + + data = TensorDict( + { + "done": torch.zeros(2, 1, dtype=torch.bool), + ("nested", "done"): torch.ones(2, 1, dtype=torch.bool), + }, + [2], + ) + assert terminated_or_truncated(data) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + + data = TensorDict( + { + "done": torch.zeros(2, 1, dtype=torch.bool), + ("nested", "done"): torch.zeros(2, 1, dtype=torch.bool), + }, + [2], + ) + assert not terminated_or_truncated(data, write_full_false=False) + assert data.get("_reset", None) is None + assert data.get(("nested", "_reset"), None) is None + assert not terminated_or_truncated(data, write_full_false=True) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + + data = TensorDict( + { + "terminated": torch.zeros(2, 1, dtype=torch.bool), + "truncated": torch.ones(2, 1, dtype=torch.bool), + ("nested", "terminated"): torch.zeros(2, 1, dtype=torch.bool), + }, + [2], + ) + assert terminated_or_truncated(data, write_full_false=False) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + assert data["_reset"].all() + assert not data["nested", "_reset"].any() + + def test_terminated_or_truncated_spec(self): + spec = CompositeSpec( + done=DiscreteTensorSpec(2, shape=(2, 1), dtype=torch.bool), + shape=[ + 2, + ], + ) + data = TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2]) + assert not terminated_or_truncated( + data, write_full_false=True, full_done_spec=spec + ) + assert data["_reset"].shape == (2, 1) + assert not terminated_or_truncated( + data, write_full_false=False, full_done_spec=spec + ) + assert data.get("_reset", None) is None + + spec = CompositeSpec( + { + "done": DiscreteTensorSpec(2, shape=(2, 1), dtype=torch.bool), + ("nested", "done"): DiscreteTensorSpec( + 2, shape=(2, 1), dtype=torch.bool + ), + }, + shape=[ + 2, + ], + ) + data = TensorDict( + { + "done": torch.zeros(2, 1, dtype=torch.bool), + ("nested", "done"): torch.ones(2, 1, dtype=torch.bool), + }, + 
[2], + ) + assert terminated_or_truncated(data, full_done_spec=spec) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + + data = TensorDict( + { + "done": torch.zeros(2, 1, dtype=torch.bool), + ("nested", "done"): torch.zeros(2, 1, dtype=torch.bool), + }, + [2], + ) + assert not terminated_or_truncated( + data, write_full_false=False, full_done_spec=spec + ) + assert data.get("_reset", None) is None + assert data.get(("nested", "_reset"), None) is None + assert not terminated_or_truncated( + data, write_full_false=True, full_done_spec=spec + ) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + + spec = CompositeSpec( + { + "truncated": DiscreteTensorSpec(2, shape=(2, 1), dtype=torch.bool), + "terminated": DiscreteTensorSpec(2, shape=(2, 1), dtype=torch.bool), + ("nested", "terminated"): DiscreteTensorSpec( + 2, shape=(2, 1), dtype=torch.bool + ), + }, + shape=[2], + ) + data = TensorDict( + { + "terminated": torch.zeros(2, 1, dtype=torch.bool), + "truncated": torch.ones(2, 1, dtype=torch.bool), + ("nested", "terminated"): torch.zeros(2, 1, dtype=torch.bool), + }, + [2], + ) + assert terminated_or_truncated( + data, write_full_false=False, full_done_spec=spec + ) + assert data["_reset"].shape == (2, 1) + assert data["nested", "_reset"].shape == (2, 1) + assert data["_reset"].all() + assert not data["nested", "_reset"].any() + + +@pytest.mark.skipif( + IS_OSX, reason="setting different threads across workers can randomly fail on OSX." +) +def test_num_threads(): + from torchrl.envs import batched_envs + + _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem + batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func( + batched_envs._run_worker_pipe_shared_mem, num_threads=3 + ) + num_threads = torch.get_num_threads() + try: + env = ParallelEnv( + 2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7 + ) + # We could test that the number of threads isn't changed until we start the procs.
+ # Even though it's unlikely that we have 7 threads, we still disable this for safety + # assert torch.get_num_threads() != 7 + env.rollout(3) + assert torch.get_num_threads() == 7 + finally: + # reset vals + batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save + torch.set_num_threads(num_threads) + + +def test_run_type_checks(): + env = ContinuousActionVecMockEnv() + env._run_type_checks = False + check_env_specs(env) + env._run_type_checks = True + check_env_specs(env) + env.output_spec.unlock_() + # check type check on done + env.output_spec["full_done_spec", "done"].dtype = torch.int + with pytest.raises(TypeError, match="expected done.dtype to"): + check_env_specs(env) + env.output_spec["full_done_spec", "done"].dtype = torch.bool + # check type check on reward + env.output_spec["full_reward_spec", "reward"].dtype = torch.int + with pytest.raises(TypeError, match="expected"): + check_env_specs(env) + env.output_spec["full_reward_spec", "reward"].dtype = torch.float + # check type check on obs + env.output_spec["full_observation_spec", "observation"].dtype = torch.float16 + with pytest.raises(TypeError): + check_env_specs(env) if __name__ == "__main__": diff --git a/test/test_exploration.py b/test/test_exploration.py index c823dbaf4f4..0caf93824ce 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,15 +14,21 @@ NestedCountingEnv, ) from scipy.stats import ttest_1samp -from tensordict.nn import InteractionType, TensorDictModule + +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from tensordict.tensordict import TensorDict from torch import nn from torchrl.collectors import SyncDataCollector -from torchrl.data import BoundedTensorSpec, CompositeSpec +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + OneHotDiscreteTensorSpec, +) from torchrl.envs import SerialEnv from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv -from torchrl.envs.utils import set_exploration_type +from torchrl.envs.utils import _replace_last, set_exploration_type from torchrl.modules import SafeModule, SafeSequential from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import ( @@ -30,23 +36,37 @@ NormalParamWrapper, ) from torchrl.modules.models.exploration import LazygSDEModule -from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor +from torchrl.modules.tensordict_module.actors import ( + Actor, + ProbabilisticActor, + QValueActor, +) from torchrl.modules.tensordict_module.exploration import ( _OrnsteinUhlenbeckProcess, AdditiveGaussianWrapper, + EGreedyModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) -@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) class TestEGreedy: - def test_egreedy(self, eps_init): + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) + @pytest.mark.parametrize("module", [True, False]) + def test_egreedy(self, eps_init, module): torch.manual_seed(0) spec = BoundedTensorSpec(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 4, bias=False) + policy = Actor(spec=spec, module=module) - explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init) + if module: + explorative_policy = TensorDictSequential( + policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec) + ) + else: + explorative_policy = EGreedyWrapper( + policy, eps_init=eps_init, eps_end=eps_init + ) td = TensorDict({"observation": torch.zeros(10, 4)}, 
batch_size=[10]) action = explorative_policy(td).get("action") if eps_init == 0: @@ -58,6 +78,135 @@ def test_egreedy(self, eps_init): assert (action == 0).any() assert ((action == 1) | (action == 0)).all() + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) + @pytest.mark.parametrize("module", [True, False]) + @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"]) + def test_egreedy_masked(self, module, eps_init, spec_class): + torch.manual_seed(0) + action_size = 4 + batch_size = (3, 4, 2) + module = torch.nn.Linear(action_size, action_size, bias=False) + if spec_class == "discrete": + spec = DiscreteTensorSpec(action_size) + else: + spec = OneHotDiscreteTensorSpec( + action_size, + shape=(action_size,), + ) + policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask") + if module: + explorative_policy = TensorDictSequential( + policy, + EGreedyModule( + eps_init=eps_init, + eps_end=eps_init, + spec=spec, + action_mask_key="action_mask", + ), + ) + else: + explorative_policy = EGreedyWrapper( + policy, + eps_init=eps_init, + eps_end=eps_init, + action_mask_key="action_mask", + ) + + td = TensorDict( + {"observation": torch.zeros(*batch_size, action_size)}, + batch_size=batch_size, + ) + with pytest.raises(KeyError, match="Action mask key action_mask not found in"): + explorative_policy(td) + + torch.manual_seed(0) + action_mask = torch.ones(*batch_size, action_size).to(torch.bool) + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + "action_mask": action_mask, + }, + batch_size=batch_size, + ) + action = explorative_policy(td).get("action") + + torch.manual_seed(0) + action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to( + torch.bool + ) + while not action_mask.any(dim=-1).all() or action_mask.all(): + action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to( + torch.bool + ) + + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + "action_mask": action_mask, + }, + batch_size=batch_size, + ) + masked_action = explorative_policy(td).get("action") + + if spec_class == "discrete": + action = spec.to_one_hot(action) + masked_action = spec.to_one_hot(masked_action) + + assert not (action[~action_mask] == 0).all() + assert (masked_action[~action_mask] == 0).all() + + def test_egreedy_wrapper_deprecation(self): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 4, bias=False) + policy = Actor(spec=spec, module=module) + with pytest.deprecated_call(): + EGreedyWrapper(policy) + + def test_no_spec_error( + self, + ): + torch.manual_seed(0) + action_size = 4 + batch_size = (3, 4, 2) + module = torch.nn.Linear(action_size, action_size, bias=False) + spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,)) + policy = QValueActor(spec=spec, module=module) + explorative_policy = TensorDictSequential( + policy, + EGreedyModule(spec=None), + ) + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + }, + batch_size=batch_size, + ) + + with pytest.raises( + RuntimeError, match="spec must be provided to the exploration wrapper." 
+ ): + explorative_policy(td) + + @pytest.mark.parametrize("module", [True, False]) + def test_wrong_action_shape(self, module): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 5, bias=False) + + policy = Actor(spec=spec, module=module) + if module: + explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec)) + else: + explorative_policy = EGreedyWrapper( + policy, + ) + td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + with pytest.raises( + ValueError, match="Action spec shape does not match the action shape" + ): + explorative_policy(td) + @pytest.mark.parametrize("device", get_default_devices()) class TestOrnsteinUhlenbeckProcessWrapper: @@ -186,7 +335,7 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): @pytest.mark.parametrize("nested_obs_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) - @pytest.mark.parametrize("is_init_key", ["some", ("one", "nested")]) + @pytest.mark.parametrize("is_init_key", ["some"]) def test_nested( self, device, @@ -232,7 +381,11 @@ def test_nested( device=device, ) for _td in collector: - assert _td[is_init_key].shape == _td[env.done_key].shape + for done_key in env.done_keys: + assert ( + _td[_replace_last(done_key, is_init_key)].shape + == _td[done_key].shape + ) break return diff --git a/test/test_helpers.py b/test/test_helpers.py index 51932185fea..1843a3f738f 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -136,6 +136,7 @@ def test_dqn_maker( expected_keys = [ "done", + "terminated", "action", "action_value", "step_count", @@ -212,6 +213,7 @@ def test_redq_make(device, from_pixels, gsde, exploration): actor(td) expected_keys = [ "done", + "terminated", "action", "sample_log_prob", "loc", @@ -247,6 +249,7 @@ def test_redq_make(device, from_pixels, gsde, exploration): qvalue(td) expected_keys = [ "done", + "terminated", "action", "sample_log_prob", "state_action_value", @@ -322,7 +325,9 @@ def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture "action", "belief", "done", + "terminated", ("next", "done"), + ("next", "terminated"), ("next", "reward"), ("next", "belief"), ("next", "encoded_latents"), @@ -346,7 +351,9 @@ def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture "action", "belief", "done", + "terminated", ("next", "done"), + ("next", "terminated"), ("next", "reward"), ("next", "belief"), ("next", "state"), @@ -475,7 +482,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial if keys: obs_spec = CompositeSpec( **{ - key: BoundedTensorSpec(maximum=1, minimum=1, shape=torch.Size([1])) + key: BoundedTensorSpec(high=1, low=1, shape=torch.Size([1])) for key in keys } ) @@ -483,7 +490,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial env = ContinuousActionVecMockEnv( device=device, observation_spec=obs_spec, - action_spec=BoundedTensorSpec(minimum=1, maximum=2, shape=torch.Size((1,))), + action_spec=BoundedTensorSpec(low=1, high=2, shape=torch.Size((1,))), ) env.out_key = "observation" else: diff --git a/test/test_libs.py b/test/test_libs.py index addd4def125..ae1218400ba 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2,6 +2,20 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import importlib +from contextlib import nullcontext + +from torchrl.envs.transforms import ActionMask, TransformedEnv +from torchrl.modules import MaskedCategorical + +_has_isaac = importlib.util.find_spec("isaacgym") is not None + +if _has_isaac: + # isaac gym asks to be imported before torch... + import isaacgym # noqa + import isaacgymenvs # noqa + from torchrl.envs.libs.isaacgym import IsaacGymEnv + import argparse import importlib @@ -13,16 +27,24 @@ import pytest import torch -import torchrl from _utils_internal import ( _make_multithreaded_env, CARTPOLE_VERSIONED, get_available_devices, + get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED, + rand_reset, + rollout_consistency_assertion, ) from packaging import version +from tensordict import LazyStackedTensorDict +from tensordict.nn import ( + ProbabilisticTensorDictModule, + TensorDictModule, + TensorDictSequential, +) from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import implement_for @@ -37,8 +59,10 @@ ParallelEnv, RenameTransform, ) +from torchrl.envs.batched_envs import SerialEnv from torchrl.envs.libs.brax import _has_brax, BraxEnv from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper +from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper from torchrl.envs.libs.gym import ( _has_gym, _is_from_pixels, @@ -46,13 +70,16 @@ GymWrapper, MOGymEnv, MOGymWrapper, + set_gym_backend, ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv +from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv +from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv +from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper -from torchrl.envs.utils import check_env_specs, ExplorationType -from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv +from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator _has_d4rl = importlib.util.find_spec("d4rl") is not None @@ -61,6 +88,7 @@ _has_sklearn = importlib.util.find_spec("sklearn") is not None +_has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None if _has_gym: try: @@ -88,6 +116,7 @@ if _has_vmas: import vmas + if _has_envpool: import envpool @@ -101,18 +130,18 @@ class TestGym: @pytest.mark.parametrize( "env_name", [ + HALFCHEETAH_VERSIONED, PONG_VERSIONED, # PENDULUM_VERSIONED, - HALFCHEETAH_VERSIONED, ], ) @pytest.mark.parametrize("frame_skip", [1, 3]) @pytest.mark.parametrize( "from_pixels,pixels_only", [ - [False, False], [True, True], [True, False], + [False, False], ], ) def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): @@ -125,6 +154,24 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): ): raise pytest.skip("no cuda device") + def non_null_obs(batched_td): + if from_pixels: + pix_norm = batched_td.get("pixels").flatten(-3, -1).float().norm(dim=-1) + pix_norm_next = ( + batched_td.get(("next", "pixels")) + .flatten(-3, -1) + .float() + .norm(dim=-1) + ) + idx = (pix_norm > 1) & (pix_norm_next > 1) + # eliminate batch size: all idx must be True (otherwise one could be filled with 0s) + while idx.ndim > 1: + idx = idx.all(0) + idx = idx.nonzero().squeeze(-1) + assert 
idx.numel(), "Did not find pixels with norm > 1" + return idx + return slice(None) + tdreset = [] tdrollout = [] final_seed = [] @@ -139,14 +186,22 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): np.random.seed(0) final_seed.append(env0.set_seed(0)) tdreset.append(env0.reset()) - tdrollout.append(env0.rollout(max_steps=50)) + rollout = env0.rollout(max_steps=50) + tdrollout.append(rollout) assert env0.from_pixels is from_pixels env0.close() env_type = type(env0._env) - del env0 assert_allclose_td(*tdreset, rtol=RTOL, atol=ATOL) - assert_allclose_td(*tdrollout, rtol=RTOL, atol=ATOL) + tdrollout = torch.stack(tdrollout, 0).contiguous() + + # custom filtering of non-null obs: mujoco rendering sometimes fails + # and renders black images. To counter this in the tests, we select + # tensordicts with all non-null observations + idx = non_null_obs(tdrollout) + assert_allclose_td( + tdrollout[0][..., idx], tdrollout[1][..., idx], rtol=RTOL, atol=ATOL + ) final_seed0, final_seed1 = final_seed assert final_seed0 == final_seed1 @@ -159,7 +214,15 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): if from_pixels and not _is_from_pixels(base_env): base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only) assert type(base_env) is env_type + + # Compare GymEnv output with GymWrapper output env1 = GymWrapper(base_env, frame_skip=frame_skip) + assert env0.get_library_name(env0._env) == env1.get_library_name(env1._env) + # check that we didn't do more wrapping + assert type(env0._env) == type(env1._env) # noqa: E721 + assert env0.output_spec == env1.output_spec + assert env0.input_spec == env1.input_spec + del env0 torch.manual_seed(0) np.random.seed(0) final_seed2 = env1.set_seed(0) @@ -171,7 +234,12 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): assert_allclose_td(tdreset[0], tdreset2, rtol=RTOL, atol=ATOL) assert final_seed0 == final_seed2 - assert_allclose_td(tdrollout[0], rollout2, rtol=RTOL, atol=ATOL) + # same magic trick for mujoco as above + tdrollout = torch.stack([tdrollout[0], rollout2], 0).contiguous() + idx = non_null_obs(tdrollout) + assert_allclose_td( + tdrollout[0][..., idx], tdrollout[1][..., idx], rtol=RTOL, atol=ATOL + ) @pytest.mark.parametrize( "env_name", @@ -280,6 +348,304 @@ def info_reader(info, tensordict): env.rand_step() env.rollout(3) + @implement_for("gymnasium", "0.27.0", None) + def test_one_hot_and_categorical(self): + # tests that one-hot and categorical work ok when an integer is expected as action + cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True) + cliff_walking.rollout(10) + check_env_specs(cliff_walking) + + cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=False) + cliff_walking.rollout(10) + check_env_specs(cliff_walking) + + @implement_for("gym", None, "0.27.0") + def test_one_hot_and_categorical(self): # noqa: F811 + # we do not skip (bc we may want to make sure nothing is skipped) + # but CliffWalking-v0 in earlier Gym versions uses np.bool, which + # was deprecated after np 1.20, and we don't want to install multiple np + # versions. 
+ return + + @implement_for("gymnasium", "0.27.0", None) + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + @pytest.mark.flaky(reruns=3, reruns_delay=1) + def test_vecenvs_wrapper(self, envname): + import gymnasium + + # we can't use parametrize with implement_for + env = GymWrapper( + gymnasium.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gymnasium.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + + @implement_for("gymnasium", "0.27.0", None) + # this env has Dict-based observation which is a nice thing to test + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + @pytest.mark.flaky(reruns=3, reruns_delay=1) + def test_vecenvs_env(self, envname): + from _utils_internal import rollout_consistency_assertion + + with set_gym_backend("gymnasium"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + + assert env.get_library_name(env._env) == "gymnasium" + # rollouts can be executed without decorator + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + @pytest.mark.flaky(reruns=3, reruns_delay=1) + def test_vecenvs_wrapper(self, envname): # noqa: F811 + import gym + + # we can't use parametrize with implement_for + for envname in ["CartPole-v1", "HalfCheetah-v4"]: + env = GymWrapper( + gym.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gym.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + @pytest.mark.flaky(reruns=3, reruns_delay=1) + def test_vecenvs_env(self, envname): # noqa: F811 + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + + assert env.get_library_name(env._env) == "gym" + # rollouts can be executed without decorator + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + if envname != "CartPole-v1": + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + # rollouts can be executed without decorator + check_env_specs(env) + + @implement_for("gym", None, "0.18") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_wrapper(self, envname): # noqa: F811 + # skipping tests for older versions of gym + ... + + @implement_for("gym", None, "0.18") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self, envname): # noqa: F811 + # skipping tests for older versions of gym + ... 
+ + @implement_for("gym", None, "0.26") + @pytest.mark.parametrize("wrapper", [True, False]) + def test_gym_output_num(self, wrapper): + # gym has 4 outputs, no truncation + import gym + + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED)) + else: + with set_gym_backend("gym"): + env = GymEnv(PENDULUM_VERSIONED) + # truncated is read from the info + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) + + @implement_for("gym", "0.26", None) + @pytest.mark.parametrize("wrapper", [True, False]) + def test_gym_output_num(self, wrapper): # noqa: F811 + # gym has 5 outputs, with truncation + import gym + + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED)) + else: + with set_gym_backend("gym"): + env = GymEnv(PENDULUM_VERSIONED) + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) + + if wrapper: + # let's further test with a wrapper that exposes the env with old API + from gym.wrappers.compatibility import EnvCompatibility + + with pytest.raises( + ValueError, + match="GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility", + ): + GymWrapper(EnvCompatibility(gym.make("CartPole-v1"))) + + @implement_for("gymnasium", "0.27", None) + @pytest.mark.parametrize("wrapper", [True, False]) + def test_gym_output_num(self, wrapper): # noqa: F811 + # gym has 5 outputs, with truncation + import gymnasium as gym + + if wrapper: + env = GymWrapper(gym.make(PENDULUM_VERSIONED)) + else: + with set_gym_backend("gymnasium"): + env = GymEnv(PENDULUM_VERSIONED) + assert "truncated" in env.done_keys + assert "terminated" in env.done_keys + assert "done" in env.done_keys + check_env_specs(env) + + def test_gym_gymnasium_parallel(self): + # tests that both gym and gymnasium work with wrappers without + # decorating with set_gym_backend during execution + if importlib.util.find_spec("gym") is not None: + import gym + + old_api = version.parse(gym.__version__) < version.parse("0.26") + make_fun = EnvCreator(lambda: GymWrapper(gym.make(PENDULUM_VERSIONED))) + elif importlib.util.find_spec("gymnasium") is not None: + import gymnasium + + old_api = False + make_fun = EnvCreator( + lambda: GymWrapper(gymnasium.make(PENDULUM_VERSIONED)) + ) + else: + raise ImportError # unreachable under pytest.skipif + penv = ParallelEnv(2, make_fun) + rollout = penv.rollout(2) + if old_api: + assert "terminated" in rollout.keys() + # truncated is read from info + assert "truncated" in rollout.keys() + else: + assert "terminated" in rollout.keys() + assert "truncated" in rollout.keys() + check_env_specs(penv) + + @implement_for("gym", None, "0.22.0") + def test_vecenvs_nan(self): # noqa: F811 + # old versions of gym must return nan for next values when there is a done state + torch.manual_seed(0) + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + rollout = env.rollout(200) + assert torch.isfinite(rollout.get("observation")).all() + assert not torch.isfinite(rollout.get(("next", "observation"))).all() + env.close() + del env + + # same with collector + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + c = SyncDataCollector( + env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 + ) + for rollout in c: + assert torch.isfinite(rollout.get("observation")).all() + assert not torch.isfinite(rollout.get(("next", "observation"))).all() + break + del c + return + + @implement_for("gym", "0.22.0", None) + def 
test_vecenvs_nan(self): # noqa: F811 + # new versions of gym must never return nan for next values when there is a done state + torch.manual_seed(0) + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + rollout = env.rollout(200) + assert torch.isfinite(rollout.get("observation")).all() + assert torch.isfinite(rollout.get(("next", "observation"))).all() + env.close() + del env + + # same with collector + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + c = SyncDataCollector( + env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 + ) + for rollout in c: + assert torch.isfinite(rollout.get("observation")).all() + assert torch.isfinite(rollout.get(("next", "observation"))).all() + break + del c + return + + @implement_for("gymnasium", "0.27.0", None) + def test_vecenvs_nan(self): # noqa: F811 + # new versions of gym must never return nan for next values when there is a done state + torch.manual_seed(0) + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + rollout = env.rollout(200) + assert torch.isfinite(rollout.get("observation")).all() + assert torch.isfinite(rollout.get(("next", "observation"))).all() + env.close() + del env + + # same with collector + env = GymEnv("CartPole-v0", num_envs=2) + env.set_seed(0) + c = SyncDataCollector( + env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 + ) + for rollout in c: + assert torch.isfinite(rollout.get("observation")).all() + assert torch.isfinite(rollout.get(("next", "observation"))).all() + break + del c + return + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 @@ -300,12 +666,7 @@ def _make_gym_environment(env_name): # noqa: F811 @pytest.mark.parametrize("env_name,task", [["cheetah", "run"]]) @pytest.mark.parametrize("frame_skip", [1, 3]) @pytest.mark.parametrize( - "from_pixels,pixels_only", - [ - [True, True], - [True, False], - [False, False], - ], + "from_pixels,pixels_only", [[True, True], [True, False], [False, False]] ) class TestDMControl: def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only): @@ -380,7 +741,7 @@ def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only): assert_allclose_td(rollout0, rollout2) def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): - if from_pixels and (not torch.has_cuda or not torch.cuda.device_count()): + if from_pixels and not torch.cuda.device_count(): raise pytest.skip("no cuda device") env = DMControlEnv( @@ -616,6 +977,11 @@ def test_jumanji_consistency(self, envname, batch_size): @pytest.mark.skipif(not _has_envpool, reason="No envpool library found") class TestEnvPool: + def test_lib(self): + import envpool + + assert MultiThreadedEnvWrapper.lib is envpool + @pytest.mark.parametrize("env_name", ENVPOOL_ALL_ENVS) def test_env_wrapper_creation(self, env_name): env_name = env_name.replace("ALE/", "") # EnvPool naming convention @@ -1029,7 +1395,7 @@ def make_brax(): @pytest.mark.skipif(not _has_vmas, reason="vmas not installed") class TestVmas: - @pytest.mark.parametrize("scenario_name", torchrl.envs.libs.vmas._get_envs()) + @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) @pytest.mark.parametrize("continuous_actions", [True, False]) def test_all_vmas_scenarios(self, scenario_name, continuous_actions): env = VmasEnv( @@ -1065,7 +1431,7 @@ def test_vmas_seeding(self, scenario_name): @pytest.mark.parametrize( "batch_size", [(), (12,), (12, 2), (12, 3), (12, 3, 1), (12, 3, 4)] ) - 
@pytest.mark.parametrize("scenario_name", torchrl.envs.libs.vmas._get_envs()) + @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) def test_vmas_batch_size_error(self, scenario_name, batch_size): num_envs = 12 n_agents = 2 @@ -1102,7 +1468,8 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size): @pytest.mark.parametrize("num_envs", [1, 20]) @pytest.mark.parametrize("n_agents", [1, 5]) @pytest.mark.parametrize( - "scenario_name", ["simple_reference", "waterfall", "flocking", "discovery"] + "scenario_name", + ["simple_reference", "simple_tag", "waterfall", "flocking", "discovery"], ) def test_vmas_batch_size(self, scenario_name, num_envs, n_agents): torch.manual_seed(0) @@ -1114,15 +1481,31 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents): ) env.set_seed(0) tdreset = env.reset() - tdrollout = env.rollout(max_steps=n_rollout_samples) + tdrollout = env.rollout( + max_steps=n_rollout_samples, + return_contiguous=False if env.het_specs else True, + ) env.close() + if env.het_specs: + assert isinstance(tdreset["agents"], LazyStackedTensorDict) + else: + assert isinstance(tdreset["agents"], TensorDict) + assert tdreset.batch_size == (num_envs,) - assert tdreset["agents", "observation"].shape[1] == env.n_agents + assert tdreset["agents"].batch_size == (num_envs, env.n_agents) + if not env.het_specs: + assert tdreset["agents", "observation"].shape[1] == env.n_agents assert tdreset["done"].shape[1] == 1 assert tdrollout.batch_size == (num_envs, n_rollout_samples) - assert tdrollout["agents", "observation"].shape[2] == env.n_agents + assert tdrollout["agents"].batch_size == ( + num_envs, + n_rollout_samples, + env.n_agents, + ) + if not env.het_specs: + assert tdrollout["agents", "observation"].shape[2] == env.n_agents assert tdrollout["next", "agents", "reward"].shape[2] == env.n_agents assert tdrollout["agents", "action"].shape[2] == env.n_agents assert tdrollout["done"].shape[2] == 1 @@ -1132,7 +1515,8 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents): @pytest.mark.parametrize("n_agents", [1, 5]) @pytest.mark.parametrize("continuous_actions", [True, False]) @pytest.mark.parametrize( - "scenario_name", ["simple_reference", "waterfall", "flocking", "discovery"] + "scenario_name", + ["simple_reference", "simple_tag", "waterfall", "flocking", "discovery"], ) def test_vmas_spec_rollout( self, scenario_name, num_envs, n_agents, continuous_actions @@ -1153,12 +1537,12 @@ def test_vmas_spec_rollout( ) for e in [env, wrapped]: e.set_seed(0) - check_env_specs(e) + check_env_specs(e, return_contiguous=False if e.het_specs else True) del e @pytest.mark.parametrize("num_envs", [1, 20]) @pytest.mark.parametrize("n_agents", [1, 5]) - @pytest.mark.parametrize("scenario_name", torchrl.envs.libs.vmas._get_envs()) + @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) def test_vmas_repr(self, scenario_name, num_envs, n_agents): if n_agents == 1 and scenario_name == "balance": return @@ -1244,17 +1628,15 @@ def make_vmas(): .all() ) - _reset = env.done_spec.rand() - while not _reset.any(): - _reset = env.done_spec.rand() - - tensordict = env.reset( - TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) + td_reset = TensorDict( + rand_reset(env), batch_size=env.batch_size, device=env.device ) - assert not tensordict["done"][_reset].all().item() + reset = td_reset["_reset"] + tensordict = env.reset(td_reset) + assert not tensordict["done"][reset].all().item() # vmas resets all the agent dimension if only one 
of the agents needs resetting # thus, here we check that where we did not reset any agent, all agents are still done - assert tensordict["done"].all(dim=2)[~_reset.any(dim=2)].all().item() + assert tensordict["done"].all(dim=2)[~reset.any(dim=2)].all().item() @pytest.mark.skipif(len(get_available_devices()) < 2, reason="not enough devices") @pytest.mark.parametrize("first", [0, 1]) @@ -1287,7 +1669,6 @@ def make_vmas(): @pytest.mark.parametrize("n_workers", [1, 2]) @pytest.mark.parametrize("n_agents", [1, 3]) def test_collector(self, n_envs, n_workers, n_agents, frames_per_batch=80): - torch.manual_seed(1) env_fun = lambda: VmasEnv( scenario="flocking", num_envs=n_envs, n_agents=n_agents, max_steps=7 @@ -1342,8 +1723,45 @@ def test_collector(self, n_envs, n_workers, n_agents, frames_per_batch=80): n_observations_per_agent, ) assert _td["next", env.reward_key].shape == agents_td_batch + (1,) - assert _td[env.done_key].shape == td_batch + (1,) - assert _td["next", env.done_key].shape == td_batch + (1,) + for done_key in env.done_keys: + assert _td[done_key].shape == td_batch + (1,) + assert _td["next", done_key].shape == td_batch + (1,) + + assert env.reward_key not in _td.keys(True, True) + assert env.action_key not in _td["next"].keys(True, True) + + def test_collector_heterogeneous(self, n_envs=10, frames_per_batch=20): + env = VmasEnv( + scenario="simple_tag", + num_envs=n_envs, + ) + torch.manual_seed(1) + + ccollector = SyncDataCollector( + create_env_fn=env, + policy=None, + frames_per_batch=frames_per_batch, + total_frames=1000, + device="cpu", + ) + + for i, _td in enumerate(ccollector): + if i == 1: + break + ccollector.shutdown() + + td_batch = (n_envs, frames_per_batch // n_envs) + agents_td_batch = td_batch + (env.n_agents,) + + assert _td.shape == td_batch + assert _td["next"].shape == td_batch + assert _td["agents"].shape == agents_td_batch + assert _td["next", "agents"].shape == agents_td_batch + assert _td["collector"].shape == td_batch + assert _td["next", env.reward_key].shape == agents_td_batch + (1,) + for done_key in env.done_keys: + assert _td[done_key].shape == td_batch + (1,) + assert _td["next", done_key].shape == td_batch + (1,) assert env.reward_key not in _td.keys(True, True) assert env.action_key not in _td["next"].keys(True, True) @@ -1352,36 +1770,88 @@ def test_collector(self, n_envs, n_workers, n_agents, frames_per_batch=80): @pytest.mark.skipif(not _has_d4rl, reason="D4RL not found") class TestD4RL: @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) - def test_terminate_on_end(self, task): - t0 = time.time() - data_true = D4RLExperienceReplay( + @pytest.mark.parametrize("use_truncated_as_done", [True, False]) + @pytest.mark.parametrize("split_trajs", [True, False]) + def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): + + with pytest.warns( + UserWarning, match="Using use_truncated_as_done=True" + ) if use_truncated_as_done else nullcontext(): + data_true = D4RLExperienceReplay( + task, + split_trajs=split_trajs, + from_env=False, + terminate_on_end=True, + batch_size=2, + use_truncated_as_done=use_truncated_as_done, + ) + _ = D4RLExperienceReplay( task, - split_trajs=True, + split_trajs=split_trajs, from_env=False, - terminate_on_end=True, + terminate_on_end=False, batch_size=2, - use_timeout_as_done=False, + use_truncated_as_done=use_truncated_as_done, ) - _ = D4RLExperienceReplay( + data_from_env = D4RLExperienceReplay( task, - split_trajs=True, + split_trajs=split_trajs, + from_env=True, + batch_size=2, + 
use_truncated_as_done=use_truncated_as_done, + ) + if not use_truncated_as_done: + keys = set(data_from_env._storage._storage.keys(True, True)) + keys = keys.intersection(data_true._storage._storage.keys(True, True)) + assert ( + data_true._storage._storage.shape + == data_from_env._storage._storage.shape + ) + assert_allclose_td( + data_true._storage._storage.select(*keys), + data_from_env._storage._storage.select(*keys), + ) + else: + leaf_names = data_from_env._storage._storage.keys(True) + leaf_names = [ + name[-1] if isinstance(name, tuple) else name for name in leaf_names + ] + assert "truncated" in leaf_names + leaf_names = data_true._storage._storage.keys(True) + leaf_names = [ + name[-1] if isinstance(name, tuple) else name for name in leaf_names + ] + assert "truncated" not in leaf_names + + @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) + def test_direct_download(self, task): + data_direct = D4RLExperienceReplay( + task, + split_trajs=False, from_env=False, - terminate_on_end=False, batch_size=2, - use_timeout_as_done=False, + use_truncated_as_done=True, + direct_download=True, ) - data_from_env = D4RLExperienceReplay( + data_d4rl = D4RLExperienceReplay( task, - split_trajs=True, + split_trajs=False, from_env=True, batch_size=2, - use_timeout_as_done=False, + use_truncated_as_done=True, + direct_download=False, + terminate_on_end=True, # keep the last time step ) - keys = set(data_from_env._storage._storage.keys(True, True)) - keys = keys.intersection(data_true._storage._storage.keys(True, True)) + keys = set(data_direct._storage._storage.keys(True, True)) + keys = keys.intersection(data_d4rl._storage._storage.keys(True, True)) + assert len(keys) assert_allclose_td( - data_true._storage._storage.select(*keys), - data_from_env._storage._storage.select(*keys), + data_direct._storage._storage.select(*keys).apply( + lambda t: t.as_tensor().float() + ), + data_d4rl._storage._storage.select(*keys).apply( + lambda t: t.as_tensor().float() + ), ) @pytest.mark.parametrize( @@ -1402,7 +1872,7 @@ def test_terminate_on_end(self, task): def test_d4rl_dummy(self, task): t0 = time.time() _ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2) - print(f"completed test after {time.time()-t0}s") + print(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -1416,11 +1886,14 @@ def test_dataset_build(self, task, split_trajs, from_env): env = GymWrapper(gym.make(task)) rollout = env.rollout(2) for key in rollout.keys(True, True): + if "truncated" in key: + # truncated is missing from static datasets + continue sim = rollout[key] offline = sample[key] assert sim.dtype == offline.dtype, key assert sim.shape[-1] == offline.shape[-1], key - print(f"completed test after {time.time()-t0}s") + print(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -1439,7 +1912,7 @@ def test_d4rl_iteration(self, task, split_trajs): for sample in data: # noqa: B007 i += 1 assert len(data) // i == batch_size - print(f"completed test after {time.time()-t0}s") + print(f"terminated test after {time.time()-t0}s") @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @@ -1481,6 +1954,415 @@ def test_data(self, dataset): assert len(data) // 2048 in (i, i - 1) +@pytest.mark.skipif(not _has_isaac, reason="IsaacGym not found") +@pytest.mark.parametrize( 
+ "task", + [ + "AllegroHand", + # "AllegroKuka", + # "AllegroKukaTwoArms", + # "AllegroHandManualDR", + # "AllegroHandADR", + "Ant", + # "Anymal", + # "AnymalTerrain", + # "BallBalance", + # "Cartpole", + # "FactoryTaskGears", + # "FactoryTaskInsertion", + # "FactoryTaskNutBoltPick", + # "FactoryTaskNutBoltPlace", + # "FactoryTaskNutBoltScrew", + # "FrankaCabinet", + # "FrankaCubeStack", + "Humanoid", + # "HumanoidAMP", + # "Ingenuity", + # "Quadcopter", + # "ShadowHand", + "Trifinger", + ], +) +@pytest.mark.parametrize("num_envs", [10, 20]) +@pytest.mark.parametrize("device", get_default_devices()) +class TestIsaacGym: + @classmethod + def _run_on_proc(cls, q, task, num_envs, device): + try: + env = IsaacGymEnv(task=task, num_envs=num_envs, device=device) + check_env_specs(env) + q.put(("succeeded!", None)) + except Exception as err: + q.put(("failed!", err)) + raise err + + def test_env(self, task, num_envs, device): + from torch import multiprocessing as mp + + q = mp.Queue(1) + proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device)) + try: + proc.start() + msg, error = q.get() + if msg != "succeeded!": + raise error + finally: + q.close() + proc.join() + + # + # def test_collector(self, task, num_envs, device): + # env = IsaacGymEnv(task=task, num_envs=num_envs, device=device) + # collector = SyncDataCollector( + # env, + # policy=SafeModule(nn.LazyLinear(out_features=env.observation_spec['obs'].shape[-1]), in_keys=["obs"], out_keys=["action"]), + # frames_per_batch=20, + # total_frames=-1 + # ) + # for c in collector: + # assert c.shape == torch.Size([num_envs, 20]) + # break + + +@pytest.mark.skipif(not _has_pettingzoo, reason="PettingZoo not found") +class TestPettingZoo: + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("continuous_actions", [True, False]) + @pytest.mark.parametrize("use_mask", [True]) + @pytest.mark.parametrize("return_state", [True, False]) + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ALL_IN_ONE_GROUP, MarlGroupMapType.ONE_GROUP_PER_AGENT], + ) + def test_pistonball( + self, parallel, continuous_actions, use_mask, return_state, group_map + ): + + kwargs = {"n_pistons": 21, "continuous": continuous_actions} + + env = PettingZooEnv( + task="pistonball_v6", + parallel=parallel, + seed=0, + return_state=return_state, + use_mask=use_mask, + group_map=group_map, + **kwargs, + ) + + check_env_specs(env) + + @pytest.mark.parametrize( + "wins_player_0", + [True, False], + ) + def test_tic_tac_toe(self, wins_player_0): + env = PettingZooEnv( + task="tictactoe_v3", + parallel=False, + group_map={"player": ["player_1", "player_2"]}, + categorical_actions=False, + seed=0, + use_mask=True, + ) + + class Policy: + + action = 0 + t = 0 + + def __call__(self, td): + new_td = env.input_spec["full_action_spec"].zero() + + player_acting = 0 if self.t % 2 == 0 else 1 + other_player = 1 if self.t % 2 == 0 else 0 + # The acting player has "mask" True and "action_mask" set to the available actions + assert td["player", "mask"][player_acting].all() + assert td["player", "action_mask"][player_acting].any() + # The non-acting player has "mask" False and "action_mask" set to all Trues + assert not td["player", "mask"][other_player].any() + assert td["player", "action_mask"][other_player].all() + + if self.t % 2 == 0: + if not wins_player_0 and self.t == 4: + new_td["player", "action"][0][self.action + 1] = 1 + else: + new_td["player", "action"][0][self.action] = 1 + else: + new_td["player", 
"action"][1][self.action + 6] = 1 + if td["player", "mask"][1].all(): + self.action += 1 + self.t += 1 + return td.update(new_td) + + td = env.rollout(100, policy=Policy()) + + assert td.batch_size[0] == (5 if wins_player_0 else 6) + assert (td[:-1]["next", "player", "reward"] == 0).all() + if wins_player_0: + assert ( + td[-1]["next", "player", "reward"] == torch.tensor([[1], [-1]]) + ).all() + else: + assert ( + td[-1]["next", "player", "reward"] == torch.tensor([[-1], [1]]) + ).all() + + @pytest.mark.parametrize( + "task", + [ + "multiwalker_v9", + "waterworld_v4", + "pursuit_v4", + "simple_spread_v3", + "simple_v3", + "rps_v2", + "cooperative_pong_v5", + "pistonball_v6", + ], + ) + def test_envs_one_group_parallel(self, task): + env = PettingZooEnv( + task=task, + parallel=True, + seed=0, + use_mask=False, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize( + "task", + [ + "multiwalker_v9", + "waterworld_v4", + "pursuit_v4", + "simple_spread_v3", + "simple_v3", + "rps_v2", + "cooperative_pong_v5", + "pistonball_v6", + "connect_four_v3", + "tictactoe_v3", + "chess_v6", + "gin_rummy_v4", + "tictactoe_v3", + ], + ) + def test_envs_one_group_aec(self, task): + env = PettingZooEnv( + task=task, + parallel=False, + seed=0, + use_mask=True, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize( + "task", + [ + "simple_adversary_v3", + "simple_crypto_v3", + "simple_push_v3", + "simple_reference_v3", + "simple_speaker_listener_v4", + "simple_tag_v3", + "simple_world_comm_v3", + "knights_archers_zombies_v10", + "basketball_pong_v3", + "boxing_v2", + "foozpong_v3", + ], + ) + def test_envs_more_groups_parallel(self, task): + env = PettingZooEnv( + task=task, + parallel=True, + seed=0, + use_mask=False, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize( + "task", + [ + "simple_adversary_v3", + "simple_crypto_v3", + "simple_push_v3", + "simple_reference_v3", + "simple_speaker_listener_v4", + "simple_tag_v3", + "simple_world_comm_v3", + "knights_archers_zombies_v10", + "basketball_pong_v3", + "boxing_v2", + "foozpong_v3", + "go_v5", + ], + ) + def test_envs_more_groups_aec(self, task): + env = PettingZooEnv( + task=task, + parallel=False, + seed=0, + use_mask=True, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) + @pytest.mark.parametrize("parallel", [True, False]) + def test_vec_env(self, task, parallel): + env_fun = lambda: PettingZooEnv( + task=task, + parallel=parallel, + seed=0, + use_mask=not parallel, + ) + vec_env = ParallelEnv(2, create_env_fn=env_fun) + vec_env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) + @pytest.mark.parametrize("parallel", [True, False]) + def test_collector(self, task, parallel): + env_fun = lambda: PettingZooEnv( + task=task, + parallel=parallel, + seed=0, + use_mask=not parallel, + ) + coll = SyncDataCollector( + create_env_fn=env_fun, frames_per_batch=30, total_frames=60, policy=None + ) + for _ in coll: + break + + +@pytest.mark.skipif(not _has_robohive, reason="SMACv2 not found") +class TestRoboHive: + # unfortunately we must import robohive to get the available envs + # and this import will occur whenever pytest is run on this file. 
+ # The other option would be not to use parametrize but that also + # means less informative error trace stacks. + # In the CI, robohive should not coexist with other libs so that's fine. + # Locally these imports can be annoying, especially given the amount of + # stuff printed by robohive. + @pytest.mark.parametrize("from_pixels", [True, False]) + @set_gym_backend("gym") + def test_robohive(self, from_pixels): + for envname in RoboHiveEnv.available_envs: + try: + if any( + substr in envname + for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") + ): + print("not testing envs with prebuilt rendering") + return + if "Adroit" in envname: + print("tcdm are broken") + return + try: + env = RoboHiveEnv(envname) + except AttributeError as err: + if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + print("tcdm are broken") + return + else: + raise err + if ( + from_pixels + and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 + ): + print("no camera") + return + check_env_specs(env) + except Exception as err: + raise RuntimeError(f"Test with robohive env {envname} failed.") from err + + +@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") +class TestSmacv2: + def test_env_procedural(self): + distribution_config = { + "n_units": 5, + "n_enemies": 6, + "team_gen": { + "dist_type": "weighted_teams", + "unit_types": ["marine", "marauder", "medivac"], + "exception_unit_types": ["medivac"], + "weights": [0.5, 0.2, 0.3], + "observe": True, + }, + "start_positions": { + "dist_type": "surrounded_and_reflect", + "p": 0.5, + "n_enemies": 5, + "map_x": 32, + "map_y": 32, + }, + } + env = SMACv2Env( + map_name="10gen_terran", + capability_config=distribution_config, + seed=0, + ) + check_env_specs(env, seed=None) + env.close() + + @pytest.mark.parametrize("categorical_actions", [True, False]) + @pytest.mark.parametrize("map", ["MMM2", "3s_vs_5z"]) + def test_env(self, map: str, categorical_actions): + env = SMACv2Env( + map_name=map, + categorical_actions=categorical_actions, + seed=0, + ) + check_env_specs(env, seed=None) + env.close() + + def test_parallel_env(self): + env = TransformedEnv( + ParallelEnv( + num_workers=2, + create_env_fn=lambda: SMACv2Env( + map_name="3s_vs_5z", + seed=0, + ), + ), + ActionMask( + action_key=("agents", "action"), mask_key=("agents", "action_mask") + ), + ) + check_env_specs(env, seed=None) + env.close() + + def test_collector(self): + env = SMACv2Env(map_name="MMM2", seed=0, categorical_actions=True) + in_feats = env.observation_spec["agents", "observation"].shape[-1] + out_feats = env.action_spec.space.n + + module = TensorDictModule( + nn.Linear(in_feats, out_feats), + in_keys=[("agents", "observation")], + out_keys=[("agents", "logits")], + ) + prob = ProbabilisticTensorDictModule( + in_keys={"logits": ("agents", "logits"), "mask": ("agents", "action_mask")}, + out_keys=[("agents", "action")], + distribution_class=MaskedCategorical, + ) + actor = TensorDictSequential(module, prob) + + collector = SyncDataCollector( + env, policy=actor, frames_per_batch=20, total_frames=40 + ) + for _ in collector: + break + collector.shutdown() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_loggers.py b/test/test_loggers.py index a4937dd0fc3..a19c8251b28 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -36,6 +36,19 @@ def tb_logger(tmp_path_factory): del logger +@pytest.fixture +def config():
return { + "value": "value", + "nested": {"inner": 3, "value": "value"}, + "int": 3, + "list": [3, 4, 5], + "tuple": (2,), + "float": 3.45, + "bool": True, + } + + @pytest.mark.skipif(not _has_tb, reason="TensorBoard not installed") class TestTensorboard: @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) @@ -98,6 +111,12 @@ def test_log_video(self, steps, tb_logger): step=steps[i] if steps else None, ) + def test_log_hparams(self, tb_logger, config): + del config["nested"] # not supported in tensorboard + del config["list"] # not supported in tensorboard + del config["tuple"] # not supported in tensorboard + tb_logger.log_hparams(config) + def test_log_histogram(self, tb_logger): torch.manual_seed(0) # test with torch @@ -110,67 +129,63 @@ def test_log_histogram(self, tb_logger): class TestCSVLogger: @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) - def test_log_scalar(self, steps): + def test_log_scalar(self, steps, tmpdir): torch.manual_seed(0) - with tempfile.TemporaryDirectory() as log_dir: - exp_name = "ramala" - logger = CSVLogger(log_dir=log_dir, exp_name=exp_name) + exp_name = "ramala" + logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name) - values = torch.rand(3) - for i in range(3): - scalar_name = "foo" - scalar_value = values[i].item() - logger.log_scalar( - value=scalar_value, - name=scalar_name, - step=steps[i] if steps else None, - ) + values = torch.rand(3) + for i in range(3): + scalar_name = "foo" + scalar_value = values[i].item() + logger.log_scalar( + value=scalar_value, + name=scalar_name, + step=steps[i] if steps else None, + ) - with open( - os.path.join(log_dir, exp_name, "scalars", "foo.csv"), "r" - ) as file: - for i, row in enumerate(file.readlines()): - step = steps[i] if steps else i - assert row == f"{step},{values[i].item()}\n" + with open(os.path.join(tmpdir, exp_name, "scalars", "foo.csv"), "r") as file: + for i, row in enumerate(file.readlines()): + step = steps[i] if steps else i + assert row == f"{step},{values[i].item()}\n" @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) - def test_log_video(self, steps): + def test_log_video(self, steps, tmpdir): torch.manual_seed(0) - with tempfile.TemporaryDirectory() as log_dir: - exp_name = "ramala" - logger = CSVLogger(log_dir=log_dir, exp_name=exp_name) + exp_name = "ramala" + logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name) - # creating a sample video (T, C, H, W), where T - number of frames, - # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. - # the first 64 frames are black and the next 64 are white - video = torch.cat( - (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + # creating a sample video (T, C, H, W), where T - number of frames, + # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. 
+ # the first 64 frames are black and the next 64 are white + video = torch.cat( + (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + ) + video = video[None, :] + for i in range(3): + logger.log_video( + name="foo", + video=video, + step=steps[i] if steps else None, ) - video = video[None, :] - for i in range(3): - logger.log_video( - name="foo", - video=video, - step=steps[i] if steps else None, - ) - sleep(0.01) # wait until events are registered + sleep(0.01) # wait until events are registered + + # check that the logged videos are the same as the initial video + video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + ".pt" + logged_video = torch.load( + os.path.join(tmpdir, exp_name, "videos", video_file_name) + ) + assert torch.equal(video, logged_video), logged_video - # check that the logged videos are the same as the initial video - video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + ".pt" - logged_video = torch.load( - os.path.join(log_dir, exp_name, "videos", video_file_name) + # check that we catch the error in case the format of the tensor is wrong + video_wrong_format = torch.zeros(64, 2, 32, 32) + video_wrong_format = video_wrong_format[None, :] + with pytest.raises(Exception): + logger.log_video( + name="foo", + video=video_wrong_format, + step=steps[i] if steps else None, ) - assert torch.equal(video, logged_video), logged_video - - # check that we catch the error in case the format of the tensor is wrong - video_wrong_format = torch.zeros(64, 2, 32, 32) - video_wrong_format = video_wrong_format[None, :] - with pytest.raises(Exception): - logger.log_video( - name="foo", - video=video_wrong_format, - step=steps[i] if steps else None, - ) def test_log_histogram(self): torch.manual_seed(0) @@ -181,6 +196,18 @@ def test_log_histogram(self): data = torch.randn(10) logger.log_histogram("hist", data, step=0, bins=2) + def test_log_config(self, tmpdir, config): + torch.manual_seed(0) + + exp_name = "ramala" + logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name) + logger.log_hparams(cfg=config) + + with open(os.path.join(tmpdir, exp_name, "texts", "hparams0.txt"), "r") as file: + txt = "\n".join([f"{k}: {val}" for k, val in sorted(config.items())]) + text = "".join(file.readlines()) + assert text == txt + @pytest.fixture(scope="class") def wandb_logger(tmp_path_factory): @@ -247,6 +274,13 @@ def test_log_video(self, wandb_logger): video=video_wrong_format, ) + def test_log_hparams(self, wandb_logger, config): + wandb_logger.log_hparams(config) + for key, value in config.items(): + if isinstance(value, tuple): + value = list(value) # wandb converts tuples to lists + assert wandb_logger.experiment.config[key] == value + def test_log_histogram(self, wandb_logger): torch.manual_seed(0) # test with torch @@ -330,6 +364,10 @@ def test_log_histogram(self, mlflow_fixture): data = torch.randn(10) logger.log_histogram("hist", data, step=0, bins=2) + def test_log_hparams(self, mlflow_fixture, config): + logger, client = mlflow_fixture + logger.log_hparams(config) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/test/test_modules.py b/test/test_modules.py index caa4cca1c9b..ee1884c5573 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -16,8 +16,11 @@ from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec from torchrl.modules import ( CEMPlanner, + DTActor, LSTMNet, + MultiAgentConvNet, MultiAgentMLP, + OnlineDTActor, QMixer, SafeModule, TanhModule, @@ -25,7 +28,11 @@ 
VDNMixer, ) from torchrl.modules.distributions.utils import safeatanh, safetanh -from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models.decision_transformer import ( + _has_transformers, + DecisionTransformer, +) from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -175,6 +182,113 @@ def test_convnet( assert y.shape == torch.Size([*batch, expected_features]) +class TestConv3d: + @pytest.mark.parametrize("in_features", [3, 10, None]) + @pytest.mark.parametrize( + "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features", + [ + (10, None, None, 3, 1, 0, 32 * 4 * 4 * 4), + (10, 3, 32, 3, 1, 1, 32 * 10 * 10 * 10), + ], + ) + @pytest.mark.parametrize( + "activation_class, activation_kwargs", + [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], + ) + @pytest.mark.parametrize( + "norm_class, norm_kwargs", + [ + (None, None), + (nn.LazyBatchNorm3d, {}), + (nn.BatchNorm3d, {"num_features": 32}), + ], + ) + @pytest.mark.parametrize("bias_last_layer", [True, False]) + @pytest.mark.parametrize( + "aggregator_class, aggregator_kwargs", + [(SquashDims, None)], + ) + @pytest.mark.parametrize("squeeze_output", [False]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("batch", [(2,), (2, 2)]) + def test_conv3dnet( + self, + batch, + in_features, + depth, + num_cells, + kernel_sizes, + strides, + paddings, + activation_class, + activation_kwargs, + norm_class, + norm_kwargs, + bias_last_layer, + aggregator_class, + aggregator_kwargs, + squeeze_output, + device, + input_size, + expected_features, + seed=0, + ): + torch.manual_seed(seed) + conv3dnet = Conv3dNet( + in_features=in_features, + depth=depth, + num_cells=num_cells, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + activation_class=activation_class, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + bias_last_layer=bias_last_layer, + aggregator_class=aggregator_class, + aggregator_kwargs=aggregator_kwargs, + squeeze_output=squeeze_output, + device=device, + ) + if in_features is None: + in_features = 5 + x = torch.randn( + *batch, in_features, input_size, input_size, input_size, device=device + ) + y = conv3dnet(x) + assert y.shape == torch.Size([*batch, expected_features]) + with pytest.raises(ValueError, match="must have at least 4 dimensions"): + conv3dnet(torch.randn(3, 16, 16)) + + def test_errors(self): + with pytest.raises( + ValueError, match="Null depth is not permitted with Conv3dNet" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=32, + depth=0, + ) + with pytest.raises( + ValueError, match="depth=None requires one of the input args" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=32, + depth=None, + ) + with pytest.raises( + ValueError, match="consider matching or specifying a constant num_cells" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=[32], + depth=None, + kernel_sizes=[3, 3], + ) + + @pytest.mark.parametrize( "layer_class", [ @@ -440,7 +554,7 @@ def test_dreamer_decoder( @pytest.mark.parametrize("action_size", [3, 6]) def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, minimum=-1, maximum=1 + shape=(action_size,), dtype=torch.float32, low=-1, high=1 ) rssm_prior = RSSMPrior( action_spec, 
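# The hunks just above and below swap BoundedTensorSpec's deprecated `minimum`/`maximum`
# keyword arguments for `low`/`high`. A minimal, self-contained sketch of the renamed
# signature as these tests assume it; the bounds, shape and dtype below are arbitrary
# illustration values rather than anything taken from the test file.
import torch
from torchrl.data.tensor_specs import BoundedTensorSpec

action_spec = BoundedTensorSpec(low=-1, high=1, shape=(6,), dtype=torch.float32)
sample = action_spec.rand()  # draws a (6,)-shaped tensor with values in [-1, 1]
assert action_spec.is_in(sample)  # in-bounds samples satisfy the spec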
@@ -495,7 +609,7 @@ def test_rssm_rollout( self, device, batch_size, temporal_size, stoch_size, deter_size, action_size ): action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, minimum=-1, maximum=1 + shape=(action_size,), dtype=torch.float32, low=-1, high=1 ) rssm_prior = RSSMPrior( action_spec, @@ -803,6 +917,58 @@ def test_mlp( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralised", [True, False]) + @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) + def test_cnn( + self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3 + ): + torch.manual_seed(0) + cnn = MultiAgentConvNet( + n_agents=n_agents, centralised=centralised, share_params=share_params + ) + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.randn(*batch, n_agents, channels, x, y)}, + [*batch, n_agents], + ) + }, + batch_size=batch, + ) + obs = td[("agents", "observation")] + out = cnn(obs) + assert out.shape[:-1] == (*batch, n_agents) + for i in range(n_agents): + if centralised and share_params: + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + obs[..., 0, 0, 0, 0] += 1 + out2 = cnn(obs) + for i in range(n_agents): + if centralised: + # a modification to the input of agent 0 will impact all agents + assert not torch.allclose(out[..., i, :], out2[..., i, :]) + elif i > 0: + assert torch.allclose(out[..., i, :], out2[..., i, :]) + + obs = torch.randn(*batch, 1, channels, x, y).expand( + *batch, n_agents, channels, x, y + ) + out = cnn(obs) + for i in range(n_agents): + if share_params: + # same input same output + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + # same input different output + assert not torch.allclose(out[..., i, :], out[..., j, :]) + @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize( "batch", @@ -952,6 +1118,74 @@ def test_tanh_atanh(use_vmap, scale): torch.testing.assert_close(x.grad, torch.ones_like(x)) +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed for TestDecisionTransformer" +) +class TestDecisionTransformer: + def test_init(self): + DecisionTransformer( + 3, + 4, + ) + with pytest.raises(TypeError): + DecisionTransformer(3, 4, config="some_str") + DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_exec(self, batch_dims, T=5): + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + model = DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + out = model(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 16]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_dtactor(self, batch_dims, T=5): + dtactor = DTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + out = 
dtactor(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 4]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_onlinedtactor(self, batch_dims, T=5): + dtactor = OnlineDTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + mu, sig = dtactor(observations, actions, r2go) + assert mu.shape == torch.Size([*batch_dims, T, 4]) + assert sig.shape == torch.Size([*batch_dims, T, 4]) + assert (dtactor.log_std_min < sig.log()).all() + assert (dtactor.log_std_max > sig.log()).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_rb.py b/test/test_rb.py index 7db3c386beb..36158d8a69e 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -5,6 +5,7 @@ import argparse import importlib +import pickle import sys from functools import partial from unittest import mock @@ -242,6 +243,17 @@ def test_index(self, rb_type, sampler, writer, storage, size): b = b.all() assert b + def test_pickable(self, rb_type, sampler, writer, storage, size): + + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + serialized = pickle.dumps(rb) + rb2 = pickle.loads(serialized) + assert rb.__dict__.keys() == rb2.__dict__.keys() + for key in sorted(rb.__dict__.keys()): + assert isinstance(rb.__dict__[key], type(rb2.__dict__[key])) + @pytest.mark.parametrize("storage_type", [TensorStorage]) class TestStorages: diff --git a/test/test_rlhf.py b/test/test_rlhf.py index bdb40bf3747..2abb9a6d386 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -165,7 +165,7 @@ def test_preproc_data( pre_tokenization_hook=pre_tokenization_hook, from_disk=True, root_dir=tmpdir1, - valid_size=500, + valid_size=20, ) dataset = loader._load_dataset() assert isinstance(dataset, datasets.Dataset) @@ -266,7 +266,7 @@ def test_tensordict_tokenizer( from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer.pad_token = 100 + tokenizer.pad_token = "-pad-" process = TensorDictTokenizer( tokenizer, max_length=max_length, @@ -313,7 +313,7 @@ def test_prompt_tensordict_tokenizer( from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer.pad_token = 100 + tokenizer.pad_token = "-pad-" process = PromptTensorDictTokenizer( tokenizer, max_length=max_length, @@ -453,7 +453,10 @@ def _reward_model(self): def _get_rollout_model(self, max_new_tokens=10): return RolloutFromModel( - self._model, self._ref_model, self._reward_model, max_new_tokens + model=self._model, + ref_model=self._ref_model, + reward_model=self._reward_model, + max_new_tokens=max_new_tokens, ) def test_padded_right_to_left(self): @@ -533,6 +536,7 @@ def test_rollout_from_data(self, tldr_batch_dir, max_new_tokens=10): expected_keys = { ("next", "attention_mask"), ("next", "done"), + ("next", "terminated"), ("next", "input_ids"), ("next", "reward"), "action", diff --git a/test/test_specs.py b/test/test_specs.py index 10adac74bdc..1f1dbb8b8aa 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -11,6 +11,8 @@ from _utils_internal import get_available_devices, get_default_devices, set_global_var from scipy.stats import chisquare from tensordict.tensordict import 
LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.utils import _unravel_key_to_tuple + from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, BinaryDiscreteTensorSpec, @@ -21,9 +23,11 @@ MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, + TensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) +from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) @@ -85,17 +89,7 @@ def test_unbounded(dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) +@pytest.mark.parametrize("shape", [[], torch.Size([3])]) def test_ndbounded(dtype, shape): torch.manual_seed(0) np.random.seed(0) @@ -118,7 +112,14 @@ def test_ndbounded(dtype, shape): assert ts.is_in(r) ts.encode(lb + torch.rand(10) * (ub - lb)) ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) - assert (ts.encode(ts.to_numpy(r)) == r).all() + + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + with pytest.raises(AssertionError), set_global_var( torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True ): @@ -168,7 +169,12 @@ def test_ndunbounded(dtype, n, shape): ts.to_numpy(r) assert ts.is_in(r) assert r.dtype is dtype - assert (ts.encode(ts.to_numpy(r)) == r).all() + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() @pytest.mark.parametrize("n", range(3, 10)) @@ -198,8 +204,12 @@ def test_binary(n, shape): ) assert ts.is_in(r) assert ((r == 0) | (r == 1)).all() - assert (ts.encode(r.numpy()) == r).all() - assert (ts.encode(ts.to_numpy(r)) == r).all() + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() @pytest.mark.parametrize( @@ -243,7 +253,13 @@ def test_mult_onehot(shape, ns): assert _r.shape[-1] == _n categorical = ts.to_categorical(r) assert not ts.is_in(categorical) - assert (ts.encode(categorical) == r).all() + # assert (ts.encode(categorical) == r).all() + if not shape: + assert (ts.encode(categorical) == r).all() + else: + with pytest.raises(RuntimeError, match="is invalid for input of size"): + ts.encode(categorical) + assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all() @pytest.mark.parametrize( @@ -256,15 +272,7 @@ def test_mult_onehot(shape, ns): [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], ], ) -@pytest.mark.parametrize( - "shape", - [ - None, - [], - torch.Size([3]), - torch.Size([4, 5]), - ], -) +@pytest.mark.parametrize("shape", [None, [], torch.Size([3]), torch.Size([4, 5])]) @pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) def test_multi_discrete(shape, ns, dtype): torch.manual_seed(0) @@ -301,27 +309,9 @@ def test_multi_discrete(shape, ns, dtype): assert not ts.is_in(projection) -@pytest.mark.parametrize( - "n", - [ - 1, - 4, - 7, - 99, - ], -) +@pytest.mark.parametrize("n", [1, 4, 7, 99]) @pytest.mark.parametrize("device", 
get_default_devices()) -@pytest.mark.parametrize( - "shape", - [ - None, - [], - [ - 1, - ], - [1, 2], - ], -) +@pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) def test_discrete_conversion(n, device, shape): categorical = DiscreteTensorSpec(n, device=device, shape=shape) shape_one_hot = [n] if not shape else [*shape, n] @@ -335,23 +325,8 @@ def test_discrete_conversion(n, device, shape): assert one_hot.is_in(categorical.to_one_hot(categorical.rand(shape))) -@pytest.mark.parametrize( - "ns", - [ - [ - 5, - ], - [5, 2, 3], - [4, 5, 1, 3], - ], -) -@pytest.mark.parametrize( - "shape", - [ - torch.Size([3]), - torch.Size([4, 5]), - ], -) +@pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]]) +@pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) @pytest.mark.parametrize("device", get_default_devices()) def test_multi_discrete_conversion(ns, shape, device): categorical = MultiDiscreteTensorSpec(ns, device=device) @@ -845,32 +820,30 @@ def test_equality_ndbounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec( - minimum=minimum, maximum=maximum, device=device, dtype=dtype - ) + ts = BoundedTensorSpec(low=minimum, high=maximum, device=device, dtype=dtype) ts_same = BoundedTensorSpec( - minimum=minimum, maximum=maximum, device=device, dtype=dtype + low=minimum, high=maximum, device=device, dtype=dtype ) assert ts == ts_same ts_other = BoundedTensorSpec( - minimum=minimum + 1, maximum=maximum, device=device, dtype=dtype + low=minimum + 1, high=maximum, device=device, dtype=dtype ) assert ts != ts_other ts_other = BoundedTensorSpec( - minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype + low=minimum, high=maximum + 1, device=device, dtype=dtype ) assert ts != ts_other ts_other = BoundedTensorSpec( - minimum=minimum, maximum=maximum, device="cpu:0", dtype=dtype + low=minimum, high=maximum, device="cpu:0", dtype=dtype ) assert ts != ts_other ts_other = BoundedTensorSpec( - minimum=minimum, maximum=maximum, device=device, dtype=torch.float64 + low=minimum, high=maximum, device=device, dtype=torch.float64 ) assert ts != ts_other @@ -1060,14 +1033,12 @@ def test_equality_composite(self): bounded_other = BoundedTensorSpec(0, 2, torch.Size((1,)), device, dtype) nd = BoundedTensorSpec( - minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype + low=minimum, high=maximum + 1, device=device, dtype=dtype ) nd_same = BoundedTensorSpec( - minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype - ) - _ = BoundedTensorSpec( - minimum=minimum, maximum=maximum + 3, device=device, dtype=dtype + low=minimum, high=maximum + 1, device=device, dtype=dtype ) + _ = BoundedTensorSpec(low=minimum, high=maximum + 3, device=device, dtype=dtype) # Equality tests ts = CompositeSpec(ts1=bounded) @@ -1164,7 +1135,7 @@ def test_one_hot_discrete_action_spec_rand(self): sample = action_spec.rand((100000,)) - sample_list = sample.argmax(-1) + sample_list = sample.long().argmax(-1) sample_list = [sum(sample_list == i).item() for i in range(10)] assert chisquare(sample_list).pvalue > 0.1 @@ -1244,14 +1215,7 @@ def test_ndbounded_shape(self): class TestExpand: - @pytest.mark.parametrize( - "shape1", - [ - None, - (4,), - (5, 4), - ], - ) + @pytest.mark.parametrize("shape1", [None, (4,), (5, 4)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_binary(self, shape1, shape2): spec = BinaryDiscreteTensorSpec( @@ -1373,14 +1337,7 @@ def test_composite(self): assert new_spec["spec7"].shape == torch.Size([4, *batch_size, 9]) assert 
new_spec["spec8"].shape == torch.Size([4, *batch_size, 9]) - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_discrete(self, shape1, shape2): spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) @@ -1402,14 +1359,7 @@ def test_discrete(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_multidiscrete(self, shape1, shape2): if shape1 is None: @@ -1437,14 +1387,7 @@ def test_multidiscrete(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_multionehot(self, shape1, shape2): if shape1 is None: @@ -1472,14 +1415,7 @@ def test_multionehot(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_onehot(self, shape1, shape2): if shape1 is None: @@ -1507,14 +1443,7 @@ def test_onehot(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_unbounded(self, shape1, shape2): if shape1 is None: @@ -1542,14 +1471,7 @@ def test_unbounded(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_unboundeddiscrete(self, shape1, shape2): if shape1 is None: @@ -1667,14 +1589,7 @@ def test_composite(self): assert item == spec_clone[key], key assert spec == spec.clone() - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_discrete( self, shape1, @@ -1683,14 +1598,7 @@ def test_discrete( assert spec == spec.clone() assert spec is not spec.clone() - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_multidiscrete( self, shape1, @@ -1705,14 +1613,7 @@ def test_multidiscrete( assert spec == spec.clone() assert spec is not spec.clone() - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_multionehot( self, shape1, @@ -1727,14 +1628,7 @@ def test_multionehot( assert spec == spec.clone() assert spec is not spec.clone() - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], - ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_onehot( self, shape1, @@ -1749,14 +1643,7 @@ def test_onehot( assert spec == spec.clone() assert spec is not spec.clone() - @pytest.mark.parametrize( - "shape1", - [ - None, - (), - (5,), - ], 
- ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_unbounded( self, shape1, @@ -1771,14 +1658,173 @@ def test_unbounded( assert spec == spec.clone() assert spec is not spec.clone() + @pytest.mark.parametrize("shape1", [None, (), (5,)]) + def test_unboundeddiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) + assert spec == spec.clone() + assert spec is not spec.clone() + + +class TestUnbind: + @pytest.mark.parametrize("shape1", [(5, 4)]) + def test_binary(self, shape1): + spec = BinaryDiscreteTensorSpec( + n=4, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec == torch.stack(spec.unbind(0), 0) + with pytest.raises(ValueError): + spec.unbind(-1) + @pytest.mark.parametrize( - "shape1", + "shape1,mini,maxi", [ - None, - (), - (5,), + [(10,), -torch.ones([]), torch.ones([])], + [None, -torch.ones([10]), torch.ones([])], + [None, -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([10]), torch.ones([])], + [(10,), -torch.ones([10]), torch.ones([10])], ], ) + def test_bounded(self, shape1, mini, maxi): + spec = BoundedTensorSpec( + mini, maxi, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec == torch.stack(spec.unbind(0), 0) + with pytest.raises(ValueError): + spec.unbind(-1) + + def test_composite(self): + batch_size = (5,) + spec1 = BoundedTensorSpec( + -torch.ones([*batch_size, 10]), + torch.ones([*batch_size, 10]), + shape=( + *batch_size, + 10, + ), + device="cpu", + dtype=torch.bool, + ) + spec2 = BinaryDiscreteTensorSpec( + n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool + ) + spec3 = DiscreteTensorSpec( + n=4, shape=batch_size, device="cpu", dtype=torch.long + ) + spec4 = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHotDiscreteTensorSpec( + n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec7 = UnboundedContinuousTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.float64, + ) + spec8 = UnboundedDiscreteTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.long, + ) + spec = CompositeSpec( + spec1=spec1, + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + spec7=spec7, + spec8=spec8, + shape=batch_size, + ) + assert spec == torch.stack(spec.unbind(0), 0) + assert spec == torch.stack(spec.unbind(-1), -1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_discrete( + self, + shape1, + ): + spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + assert spec == torch.stack(spec.unbind(0), 0) + assert spec == torch.stack(spec.unbind(-1), -1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multidiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec == torch.stack(spec.unbind(0), 0) + with pytest.raises(ValueError): + spec.unbind(-1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multionehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, 
device="cpu", dtype=torch.long + ) + assert spec == torch.stack(spec.unbind(0), 0) + with pytest.raises(ValueError): + spec.unbind(-1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_onehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHotDiscreteTensorSpec( + n=15, shape=shape1, device="cpu", dtype=torch.long + ) + assert spec == torch.stack(spec.unbind(0), 0) + with pytest.raises(ValueError): + spec.unbind(-1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_unbounded( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedContinuousTensorSpec( + shape=shape1, device="cpu", dtype=torch.float64 + ) + assert spec == torch.stack(spec.unbind(0), 0) + assert spec == torch.stack(spec.unbind(-1), -1) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) def test_unboundeddiscrete( self, shape1, @@ -1788,8 +1834,8 @@ def test_unboundeddiscrete( else: shape1 = (*shape1, 15) spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) - assert spec == spec.clone() - assert spec is not spec.clone() + assert spec == torch.stack(spec.unbind(0), 0) + assert spec == torch.stack(spec.unbind(-1), -1) @pytest.mark.parametrize( @@ -2147,8 +2193,10 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): def test_to_numpy(self, shape, stack_dim): c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) - c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c = torch.stack([c1, c2], stack_dim) + torch.manual_seed(0) shape = list(shape) @@ -2164,14 +2212,124 @@ def test_to_numpy(self, shape, stack_dim): with pytest.raises(AssertionError): c.to_numpy(val + 1, safe=True) + def test_malformed_stack(self, shape, stack_dim): + c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + with pytest.raises(RuntimeError, match="Dtypes differ"): + torch.stack([c1, c2], stack_dim) + + c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c2 = UnboundedContinuousTensorSpec(shape=shape, dtype=torch.float32) + c3 = UnboundedDiscreteTensorSpec(shape=shape, dtype=torch.float32) + with pytest.raises( + RuntimeError, + match="Stacking specs cannot occur: Found more than one type of specs in the list.", + ): + torch.stack([c1, c2], stack_dim) + torch.stack([c3, c2], stack_dim) -class TestStackComposite: + c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c2 = BoundedTensorSpec(-1, 1, shape=shape + (3,), dtype=torch.float32) + with pytest.raises(RuntimeError, match="Ndims differ"): + torch.stack([c1, c2], stack_dim) + + +class TestDenseStackedCompositeSpecs: def test_stack(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) c2 = c1.clone() c = torch.stack([c1, c2], 0) assert isinstance(c, CompositeSpec) + +class TestLazyStackedCompositeSpecs: + def _get_het_specs( + self, + batch_size=(), + stack_dim: int = 0, + ): + shared = BoundedTensorSpec(low=0, high=1, shape=(*batch_size, 32, 32, 3)) + hetero_3d = UnboundedContinuousTensorSpec( + shape=( + *batch_size, + 3, + ) + ) + hetero_2d = UnboundedContinuousTensorSpec( + shape=( + *batch_size, + 2, + ) + ) + lidar = BoundedTensorSpec( + low=0, + high=5, + shape=( + *batch_size, + 20, + ), + ) + + individual_0_obs = CompositeSpec( + { + "individual_0_obs_0": UnboundedContinuousTensorSpec( + shape=( + 
*batch_size, + 3, + 1, + ) + ) + }, + shape=(*batch_size, 3), + ) + individual_1_obs = CompositeSpec( + { + "individual_1_obs_0": BoundedTensorSpec( + low=0, high=3, shape=(*batch_size, 3, 1, 2) + ) + }, + shape=(*batch_size, 3), + ) + individual_2_obs = CompositeSpec( + { + "individual_1_obs_0": UnboundedContinuousTensorSpec( + shape=(*batch_size, 3, 1, 2, 3) + ) + }, + shape=(*batch_size, 3), + ) + + spec_list = [ + CompositeSpec( + { + "shared": shared, + "lidar": lidar, + "hetero": hetero_3d, + "individual_0_obs": individual_0_obs, + }, + shape=batch_size, + ), + CompositeSpec( + { + "shared": shared, + "lidar": lidar, + "hetero": hetero_2d, + "individual_1_obs": individual_1_obs, + }, + shape=batch_size, + ), + CompositeSpec( + { + "shared": shared, + "hetero": hetero_2d, + "individual_2_obs": individual_2_obs, + }, + shape=batch_size, + ), + ] + + return torch.stack(spec_list, dim=stack_dim) + def test_stack_index(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) c2 = CompositeSpec( @@ -2428,6 +2586,387 @@ def test_to_numpy(self): with pytest.raises(AssertionError): c.to_numpy(td_fail, safe=True) + def test_unsqueeze(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 1) + for unsq in range(-2, 3): + cu = c.unsqueeze(unsq) + shape = list(c.shape) + new_unsq = unsq if unsq >= 0 else c.ndim + unsq + 1 + shape.insert(new_unsq, 1) + assert cu.shape == torch.Size(shape) + cus = cu.squeeze(unsq) + assert c.shape == cus.shape, unsq + assert cus == c + + assert c.squeeze().shape == torch.Size([2, 3]) + + c = self._get_het_specs() + cu = c.unsqueeze(0) + assert cu.shape == torch.Size([1, 3]) + cus = cu.squeeze(0) + assert cus == c + + @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) + def test_len(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + assert len(c) == c.shape[0] + assert len(c) == len(c.rand()) + + @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) + def test_eq(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + c2 = self._get_het_specs(batch_size=batch_size) + + assert c == c2 and not c != c2 + assert c == c.clone() and not c != c.clone() + + del c2["shared"] + assert not c == c2 and c != c2 + + c2 = self._get_het_specs(batch_size=batch_size) + del c2[0]["lidar"] + + assert not c == c2 and c != c2 + + c2 = self._get_het_specs(batch_size=batch_size) + c2[0]["lidar"].space.low += 1 + assert not c == c2 and c != c2 + + @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) + @pytest.mark.parametrize("include_nested", [True, False]) + @pytest.mark.parametrize("leaves_only", [True, False]) + def test_del(self, batch_size, include_nested, leaves_only): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + + keys = list(c.keys(include_nested=include_nested, leaves_only=leaves_only)) + for k in keys: + del c[k] + del td_c[k] + assert len(c.keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + assert ( + len(td_c.keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + ) + + keys = list(c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) + for k in keys: + del c[k] + del td_c[k] + assert ( + len(c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + ) + assert ( + len(td_c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) + == 0 + ) + with 
pytest.raises(KeyError): + del c["individual_1_obs_0"] + with pytest.raises(KeyError): + del td_c["individual_1_obs_0"] + + del c[("individual_1_obs", "individual_1_obs_0")] + del td_c[("individual_1_obs", "individual_1_obs_0")] + + @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) + def test_is_in(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + assert c.is_in(td_c) + + del td_c["shared"] + with pytest.raises(KeyError): + assert not c.is_in(td_c) + + td_c = c.rand() + del td_c[("individual_1_obs", "individual_1_obs_0")] + with pytest.raises(KeyError): + assert not c.is_in(td_c) + + td_c = c.rand() + td_c["shared"] += 1 + assert not c.is_in(td_c) + + td_c = c.rand() + td_c[1]["individual_1_obs", "individual_1_obs_0"] += 4 + assert not c.is_in(td_c) + + td_c = c.rand() + td_c[0]["individual_0_obs", "individual_0_obs_0"] += 1 + assert c.is_in(td_c) + + def test_type_check(self): + c = self._get_het_specs() + td_c = c.rand() + + c.type_check(td_c) + c.type_check(td_c["shared"], "shared") + + @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) + def test_project(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + assert c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + del td_c["shared"] + with pytest.raises(KeyError): + c.is_in(td_c) + + td_c = c.rand() + del td_c[("individual_1_obs", "individual_1_obs_0")] + with pytest.raises(KeyError): + c.is_in(td_c) + + td_c = c.rand() + td_c["shared"] += 1 + assert not c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + td_c = c.rand() + td_c[1]["individual_1_obs", "individual_1_obs_0"] += 4 + assert not c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + td_c = c.rand() + td_c[0]["individual_0_obs", "individual_0_obs_0"] += 1 + assert c.is_in(td_c) + + def test_repr(self): + c = self._get_het_specs() + + expected = f"""LazyStackedCompositeSpec( + fields={{ + hetero: LazyStackedUnboundedContinuousTensorSpec( + shape=torch.Size([3, -1]), device=cpu, dtype=torch.float32, domain=continuous), + shared: BoundedTensorSpec( + shape=torch.Size([3, 32, 32, 3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous)}}, + exclusive_fields={{ + 0 -> + lidar: BoundedTensorSpec( + shape=torch.Size([20]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + individual_0_obs: CompositeSpec( + individual_0_obs_0: UnboundedContinuousTensorSpec( + shape=torch.Size([3, 1]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3])), + 1 -> + lidar: BoundedTensorSpec( + shape=torch.Size([20]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + individual_1_obs: CompositeSpec( + individual_1_obs_0: BoundedTensorSpec( + shape=torch.Size([3, 1, 2]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True), + 
high=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3])), + 2 -> + individual_2_obs: CompositeSpec( + individual_1_obs_0: UnboundedContinuousTensorSpec( + shape=torch.Size([3, 1, 2, 3]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3]))}}, + device=cpu, + shape={torch.Size((3,))}, + stack_dim={c.stack_dim})""" + assert expected == repr(c) + + c = c[0:2] + del c["individual_0_obs"] + del c["individual_1_obs"] + expected = f"""LazyStackedCompositeSpec( + fields={{ + hetero: LazyStackedUnboundedContinuousTensorSpec( + shape=torch.Size([2, -1]), device=cpu, dtype=torch.float32, domain=continuous), + lidar: BoundedTensorSpec( + shape=torch.Size([2, 20]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + shared: BoundedTensorSpec( + shape=torch.Size([2, 32, 32, 3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous)}}, + exclusive_fields={{ + }}, + device=cpu, + shape={torch.Size((2,))}, + stack_dim={c.stack_dim})""" + assert expected == repr(c) + + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + def test_consolidate_spec(self, batch_size): + spec = self._get_het_specs(batch_size) + spec_lazy = spec.clone() + + assert not check_no_exclusive_keys(spec_lazy) + + spec_lazy = consolidate_spec(spec_lazy, recurse_through_entries=False) + assert check_no_exclusive_keys(spec_lazy, recurse=False) + + spec_lazy = consolidate_spec(spec_lazy, recurse_through_entries=True) + assert check_no_exclusive_keys(spec_lazy, recurse=True) + + assert get_all_keys(spec, include_exclusive=True) == get_all_keys( + spec_lazy, include_exclusive=False + ) + + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) + def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): + shared = UnboundedContinuousTensorSpec( + shape=( + *batch_size, + 5, + 5, + 5, + ) + ) + lazy_spec = torch.stack( + [ + UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 6, 7)), + UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 7, 7)), + UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), + UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), + ], + dim=len(batch_size), + ) + + spec_list = [ + CompositeSpec( + { + "shared": shared, + "lazy_spec": lazy_spec, + }, + shape=batch_size, + ), + CompositeSpec( + { + "shared": shared, + }, + shape=batch_size, + ), + CompositeSpec( + {}, + shape=batch_size, + device="cpu", + ), + ] + + spec = torch.stack(spec_list, dim=0) + spec_consolidated = consolidate_spec(spec) + + assert spec_consolidated["shared"].shape == (3, *batch_size, -1, -1, -1) + assert spec_consolidated["lazy_spec"].shape == (3, *batch_size, 4, 5, -1, 7) + + assert check_no_exclusive_keys(spec_consolidated, recurse=True) + assert get_all_keys(spec, include_exclusive=True) == get_all_keys( + spec_consolidated, include_exclusive=False + ) + + @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) + def test_update(self, batch_size, stack_dim=0): + spec = 
self._get_het_specs(batch_size, stack_dim) + spec2 = self._get_het_specs(batch_size, stack_dim) + + del spec2["shared"] + spec2["hetero"] = spec2["hetero"].unsqueeze(-1) + assert spec["hetero"].shape == (3, *batch_size, -1) + spec.update(spec2) + assert spec["hetero"].shape == (3, *batch_size, -1, 1) + + spec2[1]["individual_1_obs"]["individual_1_obs_0"].space.low += 1 + assert spec[1]["individual_1_obs"]["individual_1_obs_0"].space.low.sum() == 0 + spec.update(spec2) + assert ( + spec[1]["individual_1_obs"]["individual_1_obs_0"].space.low.sum() == 0 + ) # Only non exclusive keys will be updated + + new = torch.stack( + [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], 0 + ) + spec2["new"] = new + spec.update(spec2) + assert spec["new"] == new + + @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) + @pytest.mark.parametrize("stack_dim", [0, 1]) + def test_set_item(self, batch_size, stack_dim): + spec = self._get_het_specs(batch_size, stack_dim) + + new = torch.stack( + [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], + stack_dim, + ) + spec["new"] = new + assert spec["new"] == new + + new = new.unsqueeze(-1) + spec["new"] = new + assert spec["new"] == new + + new = new.squeeze(-1) + assert spec["new"] == new.unsqueeze(-1) + + spec[("other", "key")] = new + assert spec[("other", "key")] == new + assert isinstance(spec["other"], LazyStackedCompositeSpec) + + with pytest.raises(RuntimeError, match="key should be a Sequence"): + spec[0] = new + + comp = torch.stack( + [ + CompositeSpec( + {"a": UnboundedContinuousTensorSpec(shape=(*batch_size, i))}, + shape=batch_size, + ) + for i in range(3) + ], + stack_dim, + ) + spec["comp"] = comp + assert spec["comp"] == comp + assert spec["comp", "a"] == new + # MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080. @pytest.mark.parametrize( @@ -2648,6 +3187,148 @@ def test_composite_contains(): assert ("a", ("b", ("c",))) in spec.keys(True, True) +def get_all_keys(spec: TensorSpec, include_exclusive: bool): + """Given a TensorSpec, returns all exclusive and non-exclusive keys as a set of tuples. + + Args: + spec (TensorSpec): the spec to get keys from. + include_exclusive (bool: if True, include also exclusive keys in the result. 
+ + """ + keys = set() + if isinstance(spec, LazyStackedCompositeSpec) and include_exclusive: + for t in spec._specs: + keys = keys.union(get_all_keys(t, include_exclusive)) + if isinstance(spec, CompositeSpec): + for key in spec.keys(): + keys.add((key,)) + inner_keys = get_all_keys(spec[key], include_exclusive) + for inner_key in inner_keys: + keys.add((key,) + _unravel_key_to_tuple(inner_key)) + + return keys + + +@pytest.mark.parametrize("shape", ((), (1,), (2, 3), (2, 3, 4))) +@pytest.mark.parametrize( + "spectype", ["one_hot", "categorical", "mult_one_hot", "mult_discrete"] +) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("rand_shape", ((), (2,), (2, 3))) +class TestSpecMasking: + def _make_mask(self, shape): + torch.manual_seed(0) + mask = torch.zeros(shape, dtype=torch.bool).bernoulli_() + if len(shape) == 1: + while not mask.any() or mask.all(): + mask = torch.zeros(shape, dtype=torch.bool).bernoulli_() + return mask + mask_view = mask.view(-1, shape[-1]) + for i in range(mask_view.shape[0]): + t = mask_view[i] + while not t.any() or t.all(): + t.copy_(torch.zeros_like(t).bernoulli_()) + return mask + + def _one_hot_spec(self, shape, device, n): + shape = torch.Size([*shape, n]) + mask = self._make_mask(shape).to(device) + return OneHotDiscreteTensorSpec(n, shape, device, mask=mask) + + def _mult_one_hot_spec(self, shape, device, n): + shape = torch.Size([*shape, n + n + 2]) + mask = torch.cat( + [ + self._make_mask(shape[:-1] + (n,)).to(device), + self._make_mask(shape[:-1] + (n + 2,)).to(device), + ], + -1, + ) + return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask) + + def _discrete_spec(self, shape, device, n): + mask = self._make_mask(torch.Size([*shape, n])).to(device) + return DiscreteTensorSpec(n, shape, device, mask=mask) + + def _mult_discrete_spec(self, shape, device, n): + shape = torch.Size([*shape, 2]) + mask = torch.cat( + [ + self._make_mask(shape[:-1] + (n,)).to(device), + self._make_mask(shape[:-1] + (n + 2,)).to(device), + ], + -1, + ) + return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask) + + def test_equal(self, shape, device, spectype, rand_shape, n=5): + shape = torch.Size(shape) + spec = ( + self._one_hot_spec(shape, device, n=n) + if spectype == "one_hot" + else self._discrete_spec(shape, device, n=n) + if spectype == "categorical" + else self._mult_one_hot_spec(shape, device, n=n) + if spectype == "mult_one_hot" + else self._mult_discrete_spec(shape, device, n=n) + if spectype == "mult_discrete" + else None + ) + spec_clone = spec.clone() + assert spec == spec_clone + assert spec.unsqueeze(0).squeeze(0) == spec + spec.update_mask(~spec.mask) + assert (spec.mask != spec_clone.mask).any() + assert spec != spec_clone + + def test_is_in(self, shape, device, spectype, rand_shape, n=5): + shape = torch.Size(shape) + rand_shape = torch.Size(rand_shape) + spec = ( + self._one_hot_spec(shape, device, n=n) + if spectype == "one_hot" + else self._discrete_spec(shape, device, n=n) + if spectype == "categorical" + else self._mult_one_hot_spec(shape, device, n=n) + if spectype == "mult_one_hot" + else self._mult_discrete_spec(shape, device, n=n) + if spectype == "mult_discrete" + else None + ) + s = spec.rand(rand_shape) + assert spec.is_in(s) + spec.update_mask(~spec.mask) + assert not spec.is_in(s) + + def test_project(self, shape, device, spectype, rand_shape, n=5): + shape = torch.Size(shape) + rand_shape = torch.Size(rand_shape) + spec = ( + self._one_hot_spec(shape, device, n=n) + if 
spectype == "one_hot" + else self._discrete_spec(shape, device, n=n) + if spectype == "categorical" + else self._mult_one_hot_spec(shape, device, n=n) + if spectype == "mult_one_hot" + else self._mult_discrete_spec(shape, device, n=n) + if spectype == "mult_discrete" + else None + ) + s = spec.rand(rand_shape) + assert (spec.project(s) == s).all() + spec.update_mask(~spec.mask) + sp = spec.project(s) + assert sp.shape == s.shape + if spectype == "one_hot": + assert (sp != s).any(-1).all() + assert (sp.any(-1)).all() + elif spectype == "mult_one_hot": + assert (sp != s).any(-1).all() + assert (sp.sum(-1) == 2).all() + else: + assert (sp != s).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ddcdf0f7535..4e1fbfcd1c1 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -7,8 +7,14 @@ import pytest import torch -from tensordict import TensorDict, unravel_key_list -from tensordict.nn import InteractionType, make_functional, TensorDictModule +from mocking_classes import DiscreteActionVecMockEnv +from tensordict import pad, TensorDict, unravel_key_list +from tensordict.nn import ( + InteractionType, + make_functional, + TensorDictModule, + TensorDictSequential, +) from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -18,12 +24,20 @@ from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( AdditiveGaussianWrapper, + DecisionTransformerInferenceWrapper, + DTActor, + GRUModule, LSTMModule, + MLP, NormalParamWrapper, + OnlineDTActor, + ProbabilisticActor, SafeModule, + TanhDelta, TanhNormal, ValueOperator, ) +from torchrl.modules.models.decision_transformer import _has_transformers from torchrl.modules.tensordict_module.common import ( ensure_tensordict_compatible, is_tensordict_compatible, @@ -857,7 +871,8 @@ def test_functional(self, safe, spec_type): assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -939,7 +954,8 @@ def test_functional_probabilistic(self, safe, spec_type): assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 4 del tdmodule[3] - del params["module", "3"] + with params.unlock_(): + del params["module", "3"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__getitem__") @@ -1014,7 +1030,8 @@ def test_functional_with_buffer(self, safe, spec_type): assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1103,7 +1120,8 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 4 del tdmodule[3] - del params["module", "3"] + with params.unlock_(): + del params["module", "3"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__getitem__") @@ -1184,7 +1202,8 @@ def test_vmap(self, safe, spec_type): assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - del params["module", "2"] + with params.unlock_(): + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1627,13 +1646,31 @@ def 
test_set_temporal_mode(self): out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], ) assert lstm_module.set_recurrent_mode(False) is lstm_module - assert not lstm_module.set_recurrent_mode(False).temporal_mode + assert not lstm_module.set_recurrent_mode(False).recurrent_mode assert lstm_module.set_recurrent_mode(True) is not lstm_module - assert lstm_module.set_recurrent_mode(True).temporal_mode + assert lstm_module.set_recurrent_mode(True).recurrent_mode assert set(lstm_module.set_recurrent_mode(True).parameters()) == set( lstm_module.parameters() ) + def test_noncontiguous(self): + lstm_module = LSTMModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["bork", "h0", "h1"], + out_keys=["dork", ("next", "h0"), ("next", "h1")], + ) + td = TensorDict( + { + "bork": torch.randn(3, 3), + "is_init": torch.zeros(3, 1, dtype=torch.bool), + }, + [3], + ) + padded = pad(td, [0, 5]) + lstm_module(padded) + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) def test_singel_step(self, shape): td = TensorDict( @@ -1736,6 +1773,312 @@ def test_multi_consecutive(self, shape): td_ss["intermediate"], td["intermediate"][..., -1, :] ) + def test_lstm_parallel_env(self): + from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + + # tests that hidden states are carried over with parallel envs + lstm_module = LSTMModule( + input_size=7, + hidden_size=12, + num_layers=2, + in_key="observation", + out_key="features", + ) + + def create_transformed_env(): + primer = lstm_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv(categorical_action_encoding=True) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + env = ParallelEnv( + create_env_fn=create_transformed_env, + num_workers=2, + ) + + mlp = TensorDictModule( + MLP( + in_features=12, + out_features=7, + num_cells=[], + ), + in_keys=["features"], + out_keys=["logits"], + ) + + actor_model = TensorDictSequential(lstm_module, mlp) + + actor = ProbabilisticActor( + module=actor_model, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=True, + ) + for break_when_any_done in [False, True]: + data = env.rollout(10, actor, break_when_any_done=break_when_any_done) + assert (data.get("recurrent_state_c") != 0.0).any() + assert (data.get(("next", "recurrent_state_c")) != 0.0).all() + + +class TestGRUModule: + def test_errs(self): + with pytest.raises(ValueError, match="batch_first"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=False, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="in_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=[ + "observation", + "hidden0", + "hidden1", + ], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(TypeError, match="incompatible function arguments"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys="abc", + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="in_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_key="smth", + in_keys=["observation", "hidden0", "hidden1"], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="out_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + 
batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden"), "other"], + ) + with pytest.raises(TypeError, match="incompatible function arguments"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys="abc", + ) + with pytest.raises(ValueError, match="out_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_key="smth", + out_keys=["intermediate", ("next", "hidden"), "other"], + ) + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + td = TensorDict({"observation": torch.randn(3)}, []) + with pytest.raises(KeyError, match="is_init"): + gru_module(td) + + def test_set_temporal_mode(self): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + assert gru_module.set_recurrent_mode(False) is gru_module + assert not gru_module.set_recurrent_mode(False).recurrent_mode + assert gru_module.set_recurrent_mode(True) is not gru_module + assert gru_module.set_recurrent_mode(True).recurrent_mode + assert set(gru_module.set_recurrent_mode(True).parameters()) == set( + gru_module.parameters() + ) + + def test_noncontiguous(self): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["bork", "h"], + out_keys=["dork", ("next", "h")], + ) + td = TensorDict( + { + "bork": torch.randn(3, 3), + "is_init": torch.zeros(3, 1, dtype=torch.bool), + }, + [3], + ) + padded = pad(td, [0, 5]) + gru_module(padded) + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + def test_singel_step(self, shape): + td = TensorDict( + { + "observation": torch.zeros(*shape, 3), + "is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + td = gru_module(td) + td_next = step_mdp(td, keep_other=True) + td_next = gru_module(td_next) + + assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any() + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + @pytest.mark.parametrize("t", [1, 10]) + def test_single_step_vs_multi(self, shape, t): + td = TensorDict( + { + "observation": torch.arange(t, dtype=torch.float32) + .unsqueeze(-1) + .expand(*shape, t, 3), + "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool), + }, + [*shape, t], + ) + gru_module_ss = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + gru_module_ms = gru_module_ss.set_recurrent_mode() + gru_module_ms(td) + td_ss = TensorDict( + { + "observation": torch.zeros(*shape, 3), + "is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + for _t in range(t): + gru_module_ss(td_ss) + td_ss = step_mdp(td_ss, keep_other=True) + td_ss["observation"][:] = _t + 1 + torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :]) + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + def test_multi_consecutive(self, shape): + t = 20 + td = TensorDict( + { + "observation": torch.arange(t, dtype=torch.float32) + .unsqueeze(-1) + .expand(*shape, t, 3), 
+ "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool), + }, + [*shape, t], + ) + if shape: + td["is_init"][0, ..., 13, :] = True + else: + td["is_init"][13, :] = True + + gru_module_ss = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + gru_module_ms = gru_module_ss.set_recurrent_mode() + gru_module_ms(td) + td_ss = TensorDict( + { + "observation": torch.zeros(*shape, 3), + "is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + for _t in range(t): + td_ss["is_init"][:] = td["is_init"][..., _t, :] + gru_module_ss(td_ss) + td_ss = step_mdp(td_ss, keep_other=True) + td_ss["observation"][:] = _t + 1 + torch.testing.assert_close( + td_ss["intermediate"], td["intermediate"][..., -1, :] + ) + + def test_gru_parallel_env(self): + from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + + # tests that hidden states are carried over with parallel envs + gru_module = GRUModule( + input_size=7, + hidden_size=12, + num_layers=2, + in_key="observation", + out_key="features", + ) + + def create_transformed_env(): + primer = gru_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv(categorical_action_encoding=True) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + env = ParallelEnv( + create_env_fn=create_transformed_env, + num_workers=2, + ) + + mlp = TensorDictModule( + MLP( + in_features=12, + out_features=7, + num_cells=[], + ), + in_keys=["features"], + out_keys=["logits"], + ) + + actor_model = TensorDictSequential(gru_module, mlp) + + actor = ProbabilisticActor( + module=actor_model, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=True, + ) + for break_when_any_done in [False, True]: + data = env.rollout(10, actor, break_when_any_done=break_when_any_done) + assert (data.get("recurrent_state") != 0.0).any() + assert (data.get(("next", "recurrent_state")) != 0.0).all() + def test_safe_specs(): @@ -1786,6 +2129,73 @@ def test_vmapmodule(): assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed to test DT classes" +) +class TestDecisionTransformerInferenceWrapper: + @pytest.mark.parametrize("online", [True, False]) + def test_dt_inference_wrapper(self, online): + action_key = ("nested", ("action",)) + if online: + dtactor = OnlineDTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["loc", "scale"] + actor_module = TensorDictModule( + dtactor, + in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhNormal + else: + dtactor = DTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["param"] + actor_module = TensorDictModule( + dtactor, + in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhDelta + dist_kwargs = { + "min": -1.0, + "max": 1.0, + } + actor = ProbabilisticActor( + in_keys=in_keys, + out_keys=[action_key], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + ) + inference_actor = DecisionTransformerInferenceWrapper(actor) + sequence_length = 20 + td = TensorDict( + { + "observation": torch.randn(1, sequence_length, 4), + action_key: torch.randn(1, sequence_length, 2), + "return_to_go": torch.randn(1, sequence_length, 1), + }, + 
[1], + ) + with pytest.raises( + ValueError, + match="The action key action was not found in the policy out_keys", + ): + result = inference_actor(td) + inference_actor.set_tensor_keys(action=action_key) + result = inference_actor(td) + # checks that the seq length has disappeared + assert result.get(action_key).shape == torch.Size([1, 2]) + assert inference_actor.out_keys == unravel_key_list( + sorted([action_key, *in_keys, "observation", "return_to_go"], key=str) + ) + assert set(result.keys(True, True)) - set(td.keys(True, True)) == set( + inference_actor.out_keys + ) - set(inference_actor.in_keys) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_transforms.py b/test/test_transforms.py index ce274c56f82..ef6796ea04d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6,6 +6,8 @@ import argparse import itertools +import pickle +import sys from copy import copy from functools import partial @@ -19,38 +21,51 @@ HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED, + rand_reset, retry, ) from mocking_classes import ( ContinuousActionVecMockEnv, CountingBatchedEnv, CountingEnvCountPolicy, + DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, + IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, NestedCountingEnv, ) from tensordict import unravel_key from tensordict.nn import TensorDictSequential from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.utils import _unravel_key_to_tuple from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import prod from torchrl.data import ( BoundedTensorSpec, CompositeSpec, + LazyMemmapStorage, LazyTensorStorage, ReplayBuffer, TensorDictReplayBuffer, + TensorStorage, UnboundedContinuousTensorSpec, ) from torchrl.envs import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, CenterCrop, + ClipTransform, Compose, + DeviceCastTransform, DiscreteActionProjection, + DMControlEnv, DoubleToFloat, + EndOfLifeTransform, EnvBase, EnvCreator, ExcludeTransform, @@ -63,6 +78,7 @@ NoopResetEnv, ObservationNorm, ParallelEnv, + PermuteTransform, PinMemoryTransform, R3MTransform, RandomCropTensorDict, @@ -85,15 +101,16 @@ VC1Transform, VIPTransform, ) -from torchrl.envs.libs.gym import _has_gym, GymEnv +from torchrl.envs.libs.dm_control import _has_dm_control +from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend from torchrl.envs.transforms import VecNorm from torchrl.envs.transforms.r3m import _R3MNet from torchrl.envs.transforms.rlhf import KLRewardTransform -from torchrl.envs.transforms.transforms import _has_tv +from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -from torchrl.envs.utils import check_env_specs, step_mdp -from torchrl.modules import ProbabilisticActor, TanhNormal +from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp +from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal TIMEOUT = 100.0 @@ -245,7 +262,7 @@ def test_nested(self): orig_env = NestedCountingEnv() env = TransformedEnv(orig_env, BinarizeReward(in_keys=[orig_env.reward_key])) env.rollout(3) - assert "data" in env._output_spec["_reward_spec"] + assert "data" in env._output_spec["full_reward_spec"] def test_transform_compose(self): 
torch.manual_seed(0) @@ -318,6 +335,253 @@ def test_transform_inverse(self): raise pytest.skip("No inverse for BinerizedReward") +class TestClipTransform(TransformBase): + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + device = "cpu" + batch = [20] + torch.manual_seed(0) + rb = rbclass(storage=LazyTensorStorage(20)) + + t = Compose( + ClipTransform( + in_keys=["observation", "reward"], + out_keys=["obs_clip", "reward_clip"], + in_keys_inv=["input"], + out_keys_inv=["input_clip"], + low=-0.1, + high=0.1, + ) + ) + rb.append_transform(t) + data = TensorDict({"observation": 1, "reward": 2, "input": 3}, []) + rb.add(data) + sample = rb.sample(20) + + assert (sample["observation"] == 1).all() + assert (sample["obs_clip"] == 0.1).all() + assert (sample["reward"] == 2).all() + assert (sample["reward_clip"] == 0.1).all() + assert (sample["input"] == 3).all() + assert (sample["input_clip"] == 0.1).all() + + def test_single_trans_env_check(self): + env = ContinuousActionVecMockEnv() + env = TransformedEnv( + env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + check_env_specs(env) + + def test_transform_compose(self): + t = Compose( + ClipTransform( + in_keys=["observation", "reward"], + out_keys=["obs_clip", "reward_clip"], + low=-0.1, + high=0.1, + ) + ) + data = TensorDict({"observation": 1, "reward": 2}, []) + data = t(data) + assert data["observation"] == 1 + assert data["obs_clip"] == 0.1 + assert data["reward"] == 2 + assert data["reward_clip"] == 0.1 + + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_env(self, device): + base_env = ContinuousActionVecMockEnv(device=device) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + r = env.rollout(3) + assert r.device == device + assert (r["observation"] <= 0.1).all() + assert (r["next", "observation"] <= 0.1).all() + assert (r["next", "reward"] <= 0.1).all() + assert (r["observation"] >= -0.1).all() + assert (r["next", "observation"] >= -0.1).all() + assert (r["next", "reward"] >= -0.1).all() + check_env_specs(env) + with pytest.raises( + TypeError, match="Either one or both of `high` and `low` must be provided" + ): + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=None, + high=None, + ) + with pytest.raises(TypeError, match="low and high must be scalars or None"): + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=torch.randn(2), + high=None, + ) + with pytest.raises(ValueError, match="`low` must be stricly lower than `high`"): + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=1.0, + high=-1.0, + ) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=1.0, + high=None, + ), + ) + check_env_specs(env) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=None, + high=1.0, + ), + ) + check_env_specs(env) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-1, + high=1, + ), + ) + check_env_specs(env) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + 
in_keys_inv=["observation_orig"], + low=-torch.ones(()), + high=1, + ), + ) + check_env_specs(env) + + def test_transform_inverse(self): + t = ClipTransform( + in_keys_inv=["observation", "reward"], + out_keys_inv=["obs_clip", "reward_clip"], + low=-0.1, + high=0.1, + ) + data = TensorDict({"observation": 1, "reward": 2}, []) + data = t.inv(data) + assert data["observation"] == 1 + assert data["obs_clip"] == 0.1 + assert data["reward"] == 2 + assert data["reward_clip"] == 0.1 + + def test_transform_model(self): + t = nn.Sequential( + ClipTransform( + in_keys=["observation", "reward"], + out_keys=["obs_clip", "reward_clip"], + low=-0.1, + high=0.1, + ) + ) + data = TensorDict({"observation": 1, "reward": 2}, []) + data = t(data) + assert data["observation"] == 1 + assert data["obs_clip"] == 0.1 + assert data["reward"] == 2 + assert data["reward_clip"] == 0.1 + + def test_transform_no_env(self): + t = ClipTransform( + in_keys=["observation", "reward"], + out_keys=["obs_clip", "reward_clip"], + low=-0.1, + high=0.1, + ) + data = TensorDict({"observation": 1, "reward": 2}, []) + data = t(data) + assert data["observation"] == 1 + assert data["obs_clip"] == 0.1 + assert data["reward"] == 2 + assert data["reward_clip"] == 0.1 + + def test_parallel_trans_env_check(self): + def make_env(): + env = ContinuousActionVecMockEnv() + return TransformedEnv( + env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + + env = ParallelEnv(2, make_env) + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + env = ContinuousActionVecMockEnv() + return TransformedEnv( + env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_trans_parallel_env_check(self): + env = ContinuousActionVecMockEnv() + env = TransformedEnv( + ParallelEnv(2, ContinuousActionVecMockEnv), + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + check_env_specs(env) + + def test_trans_serial_env_check(self): + env = ContinuousActionVecMockEnv() + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-0.1, + high=0.1, + ), + ) + check_env_specs(env) + + class TestCatFrames(TransformBase): @pytest.mark.parametrize("out_keys", [None, ["obs2"]]) def test_single_trans_env_check(self, out_keys): @@ -375,10 +639,10 @@ def test_nested(self, nested_dim=3, batch_size=(32, 1), rollout_length=6, cat_N= nested_dim, 1, ) - tranformed_env = TransformedEnv( + transformed_env = TransformedEnv( env, CatFrames(dim=-1, N=cat_N, in_keys=[("data", "states")]) ) - td = tranformed_env.rollout(rollout_length, policy=policy) + td = transformed_env.rollout(rollout_length, policy=policy) assert td[("data", "states")].shape == ( *batch_size, rollout_length, @@ -426,10 +690,9 @@ def test_transform_env_clone(self): value_at_clone = td["next", "observation"].clone() for _ in range(10): td = env.rand_step(td) - assert (td["next", "observation"] != value_at_clone).any() - assert ( - td["next", "observation"] == env.transform._cat_buffers_observation - ).all() + td = step_mdp(td) + assert (td["observation"] != value_at_clone).any() + assert (td["observation"] == env.transform._cat_buffers_observation).all() assert ( cloned._cat_buffers_observation == 
env.transform._cat_buffers_observation ).all() @@ -592,10 +855,10 @@ def test_catframes_transform_observation_spec(self): for key in keys: for i in range(N): assert torch.equal( - result[key].space.maximum[i], observation_spec[key].space.maximum[0] + result[key].space.high[i], observation_spec[key].space.high[0] ) assert torch.equal( - result[key].space.minimum[i], observation_spec[key].space.minimum[0] + result[key].space.low[i], observation_spec[key].space.low[0] ) @pytest.mark.parametrize("device", get_default_devices()) @@ -614,7 +877,7 @@ def test_transform_no_env(self, device, d, batch_size, dim, N): td = TensorDict(dict(zip(keys, key_tensors)), batch_size, device=device) if dim > 0: with pytest.raises( - ValueError, match="dim must be > 0 to accomodate for tensordict" + ValueError, match="dim must be < 0 to accomodate for tensordict" ): cat_frames = CatFrames(N=N, in_keys=keys, dim=dim) return @@ -934,25 +1197,26 @@ def test_transform_env(self, model, tensor_pixels_key, device): transformed_env = TransformedEnv(base_env, r3m) td = transformed_env.reset() assert td.device == device - exp_keys = {"vec", "done", "pixels_orig"} + expected_keys = {"vec", "done", "pixels_orig", "terminated"} if tensor_pixels_key: - exp_keys.add(tensor_pixels_key[0]) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + expected_keys.add(tensor_pixels_key[0]) + assert set(td.keys()) == expected_keys, set(td.keys()) - expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), "next", } ) if tensor_pixels_key: - exp_keys.add(("next", tensor_pixels_key[0])) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + expected_keys.add(("next", tensor_pixels_key[0])) + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() @pytest.mark.parametrize("stack_images", [True, False]) @@ -993,32 +1257,33 @@ def base_env_constructor(): td = transformed_env.reset() assert td.device == device if stack_images: - exp_keys = {"pixels_orig", "done", "vec"} + expected_keys = {"pixels_orig", "done", "vec", "terminated"} # assert td["vec"].shape[0] == 2 assert td["vec"].ndimension() == 1 + parallel - assert set(td.keys()) == exp_keys + assert set(td.keys()) == expected_keys else: - exp_keys = {"pixels_orig", "done", "vec", "vec2"} + expected_keys = {"pixels_orig", "done", "vec", "vec2", "terminated"} assert td["vec"].shape[0 + parallel] != 2 assert td["vec"].ndimension() == 1 + parallel assert td["vec2"].shape[0 + parallel] != 2 assert td["vec2"].ndimension() == 1 + parallel - assert set(td.keys()) == exp_keys + assert set(td.keys()) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), "next", } ) if not stack_images: - exp_keys.add(("next", "vec2")) - assert set(td.keys(True)) == exp_keys, set(td.keys()) - exp_keys + expected_keys.add(("next", "vec2")) + assert set(td.keys(True)) == expected_keys, set(td.keys()) - expected_keys transformed_env.close() def test_r3m_parallel(self, model, device): @@ -1036,23 +1301,24 @@ def test_r3m_parallel(self, model, device): td = transformed_env.reset() assert td.device == device assert td.batch_size == torch.Size([4]) - exp_keys = {"vec", "done", "pixels_orig"} + 
expected_keys = {"vec", "done", "pixels_orig", "terminated"} if tensor_pixels_key: - exp_keys.add(tensor_pixels_key) - assert set(td.keys(True)) == exp_keys + expected_keys.add(tensor_pixels_key) + assert set(td.keys(True)) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), "next", } ) - assert set(td.keys(True)) == exp_keys, set(td.keys()) - exp_keys + assert set(td.keys(True)) == expected_keys, set(td.keys()) - expected_keys transformed_env.close() del transformed_env @@ -1121,12 +1387,48 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device): + list(transformed_env.observation_spec.keys()) + ["action"] + [("next", key) for key in transformed_env.observation_spec.keys()] - + [("next", "reward"), ("next", "done"), "done", "next"] + + [ + ("next", "reward"), + ("next", "done"), + ("next", "terminated"), + "terminated", + "done", + "next", + ] ) assert set(expected_keys) == set(transformed_env.rollout(3).keys(True)) class TestStepCounter(TransformBase): + @pytest.mark.skipif(not _has_gym, reason="no gym detected") + def test_step_count_gym(self): + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), StepCounter(max_steps=30)) + env.rollout(1000) + check_env_specs(env) + + @pytest.mark.skipif(not _has_dm_control, reason="no dm_control detected") + def test_step_count_dmc(self): + env = TransformedEnv(DMControlEnv("cheetah", "run"), StepCounter(max_steps=30)) + env.rollout(1000) + check_env_specs(env) + + @pytest.mark.parametrize("update_done", [False, True]) + @pytest.mark.parametrize("max_steps", [10, None]) + def test_single_trans_env_check(self, update_done, max_steps): + env = TransformedEnv( + ContinuousActionVecMockEnv(), + StepCounter(max_steps=max_steps, update_done=update_done), + ) + check_env_specs(env) + r = env.rollout(100, break_when_any_done=False) + if update_done and max_steps: + assert r["next", "done"][r["next", "truncated"]].all() + elif max_steps: + assert not r["next", "done"][r["next", "truncated"]].all() + else: + assert "truncated" not in r.keys() + assert ("next", "truncated") not in r.keys(True) + def test_parallel_trans_env_check(self): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) @@ -1151,10 +1453,6 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), StepCounter(10)) check_env_specs(env) - def test_single_trans_env_check(self): - env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) - check_env_specs(env) - @pytest.mark.skipif(not _has_gym, reason="Gym not found") def test_transform_env(self): env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), StepCounter(10)) @@ -1162,7 +1460,7 @@ def test_transform_env(self): assert td["step_count"].max() == 9 assert td.shape[-1] == 100 - @pytest.mark.parametrize("step_key", ["step_count", ("other", "key")]) + @pytest.mark.parametrize("step_key", ["step_count", "other-key"]) @pytest.mark.parametrize("max_steps", [None, 10]) @pytest.mark.parametrize("nested_done", [True, False]) def test_nested( @@ -1177,9 +1475,11 @@ def test_nested( transformed_env = TransformedEnv( env, StepCounter( - max_steps=max_steps, step_count_key=step_key, truncated_key=env.done_key + max_steps=max_steps, + step_count_key=step_key, ), ) + step_key = transformed_env.transform.step_count_keys[0] td = transformed_env.rollout( rollout_length, 
policy=policy, break_when_any_done=False ) @@ -1195,10 +1495,17 @@ def test_nested( assert step[:max_steps].eq(torch.arange(max_steps)).all() assert step[max_steps:].eq(torch.arange(rollout_length - max_steps)).all() - _reset = env.done_spec.rand() + if nested_done: + for done_key in env.done_keys: + reset_key = (*done_key[:-1], "_reset") + _reset = env.full_done_spec[done_key].rand() + break + else: + reset_key = "_reset" + _reset = env.full_done_spec["done"].rand() td_reset = transformed_env.reset( TensorDict( - {"_reset": _reset, step_key: last_step}, + {reset_key: _reset, step_key: last_step}, batch_size=env.batch_size, device=env.device, ) @@ -1232,13 +1539,21 @@ def test_transform_compose(self, max_steps, device, batch, reset_workers): _reset = torch.randn(done.shape, device=device) < 0 td.set("_reset", _reset) td.set("done", _reset) + td.set("terminated", _reset) + td.set(("next", "terminated"), done) td.set(("next", "done"), done) - + td.set("step_count", torch.zeros(*batch, 1, dtype=torch.int)) + step_counter[0]._step_count_keys = ["step_count"] + step_counter[0]._terminated_keys = ["terminated"] + step_counter[0]._truncated_keys = ["truncated"] + step_counter[0]._reset_keys = ["_reset"] + step_counter[0]._done_keys = ["done"] td = step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 while max_steps is None or i < max_steps: - td = step_counter._step(td) + next_td = step_counter._step(td, td.get("next")) + td.set("next", next_td) i += 1 assert torch.all(td.get(("next", "step_count")) == i), ( td.get(("next", "step_count")), @@ -1246,6 +1561,7 @@ def test_transform_compose(self, max_steps, device, batch, reset_workers): ) td = step_mdp(td) td["next", "done"] = done + td["next", "terminated"] = done if max_steps is None: break @@ -1282,16 +1598,25 @@ def test_transform_no_env(self, max_steps, device, batch, reset_workers): td = TensorDict({"done": done, ("next", "done"): done}, batch, device=device) _reset = torch.zeros((), dtype=torch.bool, device=device) while not _reset.any() and reset_workers: - _reset = torch.randn(batch, device=device) < 0 + _reset = torch.randn(done.shape, device=device) < 0 td.set("_reset", _reset) + td.set("terminated", _reset) + td.set(("next", "terminated"), done) td.set("done", _reset) td.set(("next", "done"), done) + td.set("step_count", torch.zeros(*batch, 1, dtype=torch.int)) + step_counter._step_count_keys = ["step_count"] + step_counter._done_keys = ["done"] + step_counter._terminated_keys = ["terminated"] + step_counter._truncated_keys = ["truncated"] + step_counter._reset_keys = ["_reset"] + step_counter._completed_keys = ["completed"] td = step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 while max_steps is None or i < max_steps: - td = step_counter._step(td) + td.set("next", step_counter._step(td, td.get("next"))) i += 1 assert torch.all(td.get(("next", "step_count")) == i), ( td.get(("next", "step_count")), @@ -1299,6 +1624,7 @@ def test_transform_no_env(self, max_steps, device, batch, reset_workers): ) td = step_mdp(td) td["next", "done"] = done + td["next", "terminated"] = done if max_steps is None: break @@ -1934,16 +2260,10 @@ def test_double2float(self, keys, keys_inv, device): if len(keys_total) == 1 and len(keys_inv) and keys[0] == "action": action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) input_spec = CompositeSpec( - _action_spec=CompositeSpec(action=action_spec), _state_spec=None + full_action_spec=CompositeSpec(action=action_spec), full_state_spec=None ) action_spec = 
double2float.transform_input_spec(input_spec) assert action_spec.dtype == torch.float - - elif len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) - observation_spec = double2float.transform_observation_spec(observation_spec) - assert observation_spec.dtype == torch.float - else: observation_spec = CompositeSpec( { @@ -1953,7 +2273,71 @@ def test_double2float(self, keys, keys_inv, device): ) observation_spec = double2float.transform_observation_spec(observation_spec) for key in keys: - assert observation_spec[key].dtype == torch.float + assert observation_spec[key].dtype == torch.float, key + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "keys", + [ + ["observation", ("some_other", "nested_key")], + ["observation_pixels"], + ["action"], + ], + ) + @pytest.mark.parametrize( + "keys_inv", + [ + ["action", ("some_other", "nested_key")], + ["action"], + [], + ], + ) + def test_double2float_auto(self, keys, keys_inv, device): + torch.manual_seed(0) + double2float = DoubleToFloat() + d = { + key: torch.zeros(1, 3, 3, dtype=torch.double, device=device) for key in keys + } + d.update( + { + key: torch.zeros(1, 3, 3, dtype=torch.float32, device=device) + for key in keys_inv + } + ) + td = TensorDict(d, [1], device=device) + # check that the transform does change the dtype in forward + double2float(td) + for key in keys: + assert td.get(key).dtype == torch.float + + # check that inv does not affect the tensordict in-place + td = td.apply(lambda x: x.float()) + td_modif = double2float.inv(td) + for key in keys_inv: + assert td.get(key).dtype != torch.double + assert td_modif.get(key).dtype == torch.double + + def test_single_env_no_inkeys(self): + base_env = ContinuousActionVecMockEnv() + for key, spec in list(base_env.observation_spec.items(True, True)): + base_env.observation_spec[key] = spec.to(torch.float64) + for key, spec in list(base_env.state_spec.items(True, True)): + base_env.state_spec[key] = spec.to(torch.float64) + if base_env.action_spec.dtype == torch.float32: + base_env.action_spec = base_env.action_spec.to(torch.float64) + check_env_specs(base_env) + env = TransformedEnv( + base_env, + DoubleToFloat(), + ) + for spec in env.observation_spec.values(True, True): + assert spec.dtype == torch.float32 + for spec in env.state_spec.values(True, True): + assert spec.dtype == torch.float32 + assert env.action_spec.dtype != torch.float64 + assert env.transform.in_keys == env.transform.out_keys + assert env.transform.in_keys_inv == env.transform.out_keys_inv + check_env_specs(env) def test_single_trans_env_check(self, dtype_fixture): # noqa: F811 env = TransformedEnv( @@ -2178,6 +2562,24 @@ def test_transform_env(self): assert "b" in env.reset().keys() assert "c" in env.reset().keys() + def test_exclude_done(self): + base_env = TestExcludeTransform.EnvWithManyKeys() + env = TransformedEnv(base_env, ExcludeTransform("a", "done")) + assert "done" not in env.done_keys + check_env_specs(env) + env = TransformedEnv(base_env, ExcludeTransform("a")) + assert "done" in env.done_keys + check_env_specs(env) + + def test_exclude_reward(self): + base_env = TestExcludeTransform.EnvWithManyKeys() + env = TransformedEnv(base_env, ExcludeTransform("a", "reward")) + assert "reward" not in env.reward_keys + check_env_specs(env) + env = TransformedEnv(base_env, ExcludeTransform("a")) + assert "reward" in env.reward_keys + check_env_specs(env) + @pytest.mark.parametrize("nest_done", [True, False]) 
@pytest.mark.parametrize("nest_reward", [True, False]) def test_nested(self, nest_reward, nest_done): @@ -2189,8 +2591,9 @@ def test_nested(self, nest_reward, nest_done): td = transformed_env.rollout(1) td_keys = td.keys(True, True) assert ("next", env.reward_key) in td_keys - assert ("next", env.done_key) in td_keys - assert env.done_key in td_keys + for done_key in env.done_keys: + assert ("next", done_key) in td_keys + assert done_key in td_keys assert env.action_key in td_keys assert ("data", "states") in td_keys assert ("next", "data", "states") in td_keys @@ -2199,8 +2602,9 @@ def test_nested(self, nest_reward, nest_done): td = transformed_env.rollout(1) td_keys = td.keys(True, True) assert ("next", env.reward_key) in td_keys - assert ("next", env.done_key) in td_keys - assert env.done_key in td_keys + for done_key in env.done_keys: + assert ("next", done_key) in td_keys + assert done_key in td_keys assert env.action_key in td_keys assert ("data", "states") not in td_keys assert ("next", "data", "states") not in td_keys @@ -2384,6 +2788,38 @@ def test_transform_env(self): assert "b" in env.reset().keys() assert "c" in env.reset().keys() + @pytest.mark.parametrize("keep_done", [True, False]) + def test_select_done(self, keep_done): + base_env = TestExcludeTransform.EnvWithManyKeys() + env = TransformedEnv( + base_env, SelectTransform("b", "c", "done", keep_dones=keep_done) + ) + assert "done" in env.done_keys + check_env_specs(env) + env = TransformedEnv(base_env, SelectTransform("b", "c", keep_dones=keep_done)) + if keep_done: + assert "done" in env.done_keys + else: + assert "done" not in env.done_keys + check_env_specs(env) + + @pytest.mark.parametrize("keep_reward", [True, False]) + def test_select_reward(self, keep_reward): + base_env = TestExcludeTransform.EnvWithManyKeys() + env = TransformedEnv( + base_env, SelectTransform("b", "c", "reward", keep_rewards=keep_reward) + ) + assert "reward" in env.reward_keys + check_env_specs(env) + env = TransformedEnv( + base_env, SelectTransform("b", "c", keep_rewards=keep_reward) + ) + if keep_reward: + assert "reward" in env.reward_keys + else: + assert "reward" not in env.reward_keys + check_env_specs(env) + @pytest.mark.parametrize("nest_done", [True, False]) @pytest.mark.parametrize("nest_reward", [True, False]) def test_nested(self, nest_reward, nest_done): @@ -2395,8 +2831,9 @@ def test_nested(self, nest_reward, nest_done): td = transformed_env.rollout(1) td_keys = td.keys(True, True) assert ("next", env.reward_key) in td_keys - assert ("next", env.done_key) in td_keys - assert env.done_key in td_keys + for done_key in env.done_keys: + assert ("next", done_key) in td_keys + assert done_key in td_keys assert env.action_key in td_keys assert ("data", "states") not in td_keys assert ("next", "data", "states") not in td_keys @@ -2405,8 +2842,9 @@ def test_nested(self, nest_reward, nest_done): td = transformed_env.rollout(1) td_keys = td.keys(True, True) assert ("next", env.reward_key) in td_keys - assert ("next", env.done_key) in td_keys - assert env.done_key in td_keys + for done_key in env.done_keys: + assert ("next", done_key) in td_keys + assert done_key in td_keys assert env.action_key in td_keys assert ("data", "states") in td_keys assert ("next", "data", "states") in td_keys @@ -2699,7 +3137,7 @@ def test_transform_no_env(self): with pytest.raises( RuntimeError, match="parent not found for FrameSkipTransform" ): - t._step(tensordict) + t._step(tensordict, tensordict.get("next")) def test_transform_compose(self): t = 
Compose(FrameSkipTransform(2)) @@ -2707,7 +3145,7 @@ def test_transform_compose(self): with pytest.raises( RuntimeError, match="parent not found for FrameSkipTransform" ): - t._step(tensordict) + t._step(tensordict, tensordict.get("next")) @pytest.mark.skipif(not _has_gym, reason="gym not installed") @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) @@ -3023,7 +3461,8 @@ def test_transform_no_env(self): match="NoopResetEnv.parent not found. Make sure that the parent is set.", ): t.reset(TensorDict({"next": {}}, [])) - t._step(TensorDict({"next": {}}, [])) + td = TensorDict({"next": {}}, []) + t._step(td, td.get("next")) def test_transform_compose(self): t = Compose(NoopResetEnv()) @@ -3032,7 +3471,8 @@ def test_transform_compose(self): match="NoopResetEnv.parent not found. Make sure that the parent is set.", ): t.reset(TensorDict({"next": {}}, [])) - t._step(TensorDict({"next": {}}, [])) + td = TensorDict({"next": {}}, []) + t._step(td, td.get("next")) def test_transform_model(self): t = nn.Sequential(NoopResetEnv(), nn.Identity()) @@ -3105,6 +3545,20 @@ def test_noop_reset_env_error(self, random, device, compose): ): transformed_env.reset() + @pytest.mark.parametrize("noops", [0, 2, 8]) + @pytest.mark.parametrize("max_steps", [0, 5, 9]) + def test_noop_reset_limit_exceeded(self, noops, max_steps): + env = IncrementingEnv(max_steps=max_steps) + check_env_specs(env) + noop_reset_env = NoopResetEnv(noops=noops, random=False) + transformed_env = TransformedEnv(env, noop_reset_env) + if noops <= max_steps: # Normal behavior. + result = transformed_env.reset() + assert result["observation"] == noops + elif noops > max_steps: # Raise error as reset limit exceeded. + with pytest.raises(RuntimeError): + transformed_env.reset() + class TestObservationNorm(TransformBase): @pytest.mark.parametrize( @@ -3453,13 +3907,11 @@ def test_observationnorm( observation_spec = on.transform_observation_spec(observation_spec) for key in keys: if standard_normal: - assert (observation_spec[key].space.minimum == -loc / scale).all() - assert ( - observation_spec[key].space.maximum == (1 - loc) / scale - ).all() + assert (observation_spec[key].space.low == -loc / scale).all() + assert (observation_spec[key].space.high == (1 - loc) / scale).all() else: - assert (observation_spec[key].space.minimum == loc).all() - assert (observation_spec[key].space.maximum == scale + loc).all() + assert (observation_spec[key].space.low == loc).all() + assert (observation_spec[key].space.high == scale + loc).all() @pytest.mark.parametrize("keys", [["observation"], ["observation", "next_pixel"]]) @pytest.mark.parametrize("size", [1, 3]) @@ -3473,15 +3925,13 @@ def make_env(): base_env = ContinuousActionVecMockEnv( observation_spec=CompositeSpec( observation=BoundedTensorSpec( - minimum=1, maximum=1, shape=torch.Size([size]) + low=1, high=1, shape=torch.Size([size]) ), observation_orig=BoundedTensorSpec( - minimum=1, maximum=1, shape=torch.Size([size]) + low=1, high=1, shape=torch.Size([size]) ), ), - action_spec=BoundedTensorSpec( - minimum=1, maximum=1, shape=torch.Size((size,)) - ), + action_spec=BoundedTensorSpec(low=1, high=1, shape=torch.Size((size,))), seed=0, ) base_env.out_key = "observation" @@ -4032,30 +4482,85 @@ def test_transform_inverse(self): class TestRewardSum(TransformBase): def test_single_trans_env_check(self): - env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum()) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + Compose(RewardScaling(loc=-1, scale=1), RewardSum()), + ) check_env_specs(env) + r = 
env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 def test_serial_trans_env_check(self): def make_env(): - return TransformedEnv(ContinuousActionVecMockEnv(), RewardSum()) + return TransformedEnv( + ContinuousActionVecMockEnv(), + Compose(RewardScaling(loc=-1, scale=1), RewardSum()), + ) env = SerialEnv(2, make_env) check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 def test_parallel_trans_env_check(self): def make_env(): - return TransformedEnv(ContinuousActionVecMockEnv(), RewardSum()) + return TransformedEnv( + ContinuousActionVecMockEnv(), + Compose(RewardScaling(loc=-1, scale=1), RewardSum()), + ) env = ParallelEnv(2, make_env) check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 def test_trans_serial_env_check(self): - env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), RewardSum()) + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + Compose(RewardScaling(loc=-1, scale=1), RewardSum()), + ) check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 def test_trans_parallel_env_check(self): - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), RewardSum()) + env = TransformedEnv( + ParallelEnv(2, ContinuousActionVecMockEnv), + Compose(RewardScaling(loc=-1, scale=1), RewardSum()), + ) + check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 + + @pytest.mark.parametrize("has_in_keys,", [True, False]) + def test_trans_multi_key( + self, has_in_keys, n_workers=2, batch_size=(3, 2), max_steps=5 + ): + torch.manual_seed(0) + env_fun = lambda: MultiKeyCountingEnv(batch_size=batch_size) + base_env = SerialEnv(n_workers, env_fun) + if has_in_keys: + t = RewardSum(in_keys=base_env.reward_keys, reset_keys=base_env.reset_keys) + else: + t = RewardSum() + env = TransformedEnv( + base_env, + Compose(t), + ) + policy = MultiKeyCountingEnvPolicy( + full_action_spec=env.action_spec, deterministic=True + ) + check_env_specs(env) + td = env.rollout(max_steps, policy=policy) + for reward_key in env.reward_keys: + reward_key = _unravel_key_to_tuple(reward_key) + assert ( + td.get( + ("next", _replace_last(reward_key, f"episode_{reward_key[-1]}")) + )[(0,) * (len(batch_size) + 1)][-1] + == max_steps + ).all() @pytest.mark.parametrize("in_key", ["reward", ("some", "nested")]) def test_transform_no_env(self, in_key): @@ -4080,7 +4585,8 @@ def test_transform_no_env(self, in_key): def test_transform_compose( self, ): - t = Compose(RewardSum()) + # reset keys should not be needed for offline run + t = Compose(RewardSum(in_keys=["reward"], out_keys=["episode_reward"])) reward = torch.randn(10) td = TensorDict({("next", "reward"): reward}, []) with pytest.raises( @@ -4168,24 +4674,24 @@ def test_sum_reward(self, keys, device): ) # apply one time, episode_reward should be equal to reward again - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert "episode_reward" in td.keys() assert (td_next.get("episode_reward") == td_next.get("reward")).all() # apply a second time, episode_reward should twice the reward td["episode_reward"] = td["next", "episode_reward"] - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert (td_next.get("episode_reward") == 2 * td_next.get("reward")).all() # reset environments td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device)) + with pytest.raises(TypeError, 
match="reset_keys not provided but parent"): + rs.reset(td) + rs._reset_keys = ["_reset"] rs.reset(td) # apply a third time, episode_reward should be equal to reward again - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert (td_next.get("episode_reward") == td_next.get("reward")).all() # test transform_observation_spec @@ -5055,7 +5561,9 @@ def test_transform_compose(self, batch, mode, device): batch_size=batch, ) td = t.reset(td) - td = t._step(td) + next_td = td.get("next") + next_td = t._step(td, next_td) + td.set("next", next_td) if mode == "reduce": assert (td["next", "target_return"] + td["next", "reward"] == 10.0).all() @@ -5129,10 +5637,11 @@ def test_transform_no_env(self, mode, in_key, out_key): t = TargetReturn( target_return=10.0, mode=mode, in_keys=[in_key], out_keys=[out_key] ) - reward = torch.randn(10) - td = TensorDict({("next", in_key): reward}, []) + reward = torch.randn(10, 1) + td = TensorDict({("next", in_key): reward}, [10]) td = t.reset(td) - td = t._step(td) + td_next = t._step(td, td.get("next")) + td.set("next", td_next) if mode == "reduce": assert (td["next", out_key] + td["next", in_key] == 10.0).all() else: @@ -5143,8 +5652,8 @@ def test_transform_model( ): t = TargetReturn(target_return=10.0) model = nn.Sequential(t, nn.Identity()) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []) + reward = torch.randn(10, 1) + td = TensorDict({("next", "reward"): reward}, [10]) with pytest.raises( NotImplementedError, match="cannot be executed without a parent" ): @@ -5157,8 +5666,8 @@ def test_transform_rb( ): t = TargetReturn(target_return=10.0) rb = rbclass(storage=LazyTensorStorage(10)) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []).expand(10) + reward = torch.randn(10, 1) + td = TensorDict({("next", "reward"): reward}, [10]) rb.append_transform(t) rb.extend(td) with pytest.raises( @@ -5214,8 +5723,8 @@ def test_transform_no_env(self, keys, batch, device): ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) - assert (observation_spec[key].space.minimum == 0).all() - assert (observation_spec[key].space.maximum == 1).all() + assert (observation_spec[key].space.low == 0).all() + assert (observation_spec[key].space.high == 1).all() @pytest.mark.parametrize("batch", [[], [1], [3, 2]]) @pytest.mark.parametrize( @@ -5263,8 +5772,8 @@ def test_transform_compose(self, keys, batch, device): ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) - assert (observation_spec[key].space.minimum == 0).all() - assert (observation_spec[key].space.maximum == 1).all() + assert (observation_spec[key].space.low == 0).all() + assert (observation_spec[key].space.high == 1).all() @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) def test_single_trans_env_check(self, out_keys): @@ -6049,13 +6558,13 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device): transformed_env = TransformedEnv(base_env, vip) td = transformed_env.reset() assert td.device == device - exp_keys = {"vec", "done", "pixels_orig"} + expected_keys = {"vec", "done", "pixels_orig", "terminated"} if tensor_pixels_key: - exp_keys.add(tensor_pixels_key[0]) - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + expected_keys.add(tensor_pixels_key[0]) + assert set(td.keys()) == expected_keys, set(td.keys()) - expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", 
"pixels_orig"), @@ -6063,11 +6572,12 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device): "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), } ) if tensor_pixels_key: - exp_keys.add(("next", tensor_pixels_key[0])) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + expected_keys.add(("next", tensor_pixels_key[0])) + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() @pytest.mark.parametrize("stack_images", [True, False]) @@ -6102,20 +6612,20 @@ def base_env_constructor(): td = transformed_env.reset() assert td.device == device if stack_images: - exp_keys = {"pixels_orig", "done", "vec"} + expected_keys = {"pixels_orig", "done", "vec", "terminated"} # assert td["vec"].shape[0] == 2 assert td["vec"].ndimension() == 1 + parallel - assert set(td.keys()) == exp_keys + assert set(td.keys()) == expected_keys else: - exp_keys = {"pixels_orig", "done", "vec", "vec2"} + expected_keys = {"pixels_orig", "done", "vec", "vec2", "terminated"} assert td["vec"].shape[0 + parallel] != 2 assert td["vec"].ndimension() == 1 + parallel assert td["vec2"].shape[0 + parallel] != 2 assert td["vec2"].ndimension() == 1 + parallel - assert set(td.keys()) == exp_keys + assert set(td.keys()) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), @@ -6123,11 +6633,12 @@ def base_env_constructor(): "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), } ) if not stack_images: - exp_keys.add(("next", "vec2")) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + expected_keys.add(("next", "vec2")) + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() def test_transform_env(self, model, device): @@ -6145,13 +6656,13 @@ def test_transform_env(self, model, device): td = transformed_env.reset() assert td.device == device assert td.batch_size == torch.Size([4]) - exp_keys = {"vec", "done", "pixels_orig"} + expected_keys = {"vec", "done", "pixels_orig", "terminated"} if tensor_pixels_key: - exp_keys.add(tensor_pixels_key) - assert set(td.keys()) == exp_keys + expected_keys.add(tensor_pixels_key) + assert set(td.keys()) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), @@ -6159,9 +6670,10 @@ def test_transform_env(self, model, device): "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), } ) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() del transformed_env @@ -6197,19 +6709,20 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa td = transformed_env.reset(tensordict_reset) assert td.device == device assert td.batch_size == torch.Size([4]) - exp_keys = { + expected_keys = { "vec", "done", "pixels_orig", "goal_embedding", "goal_image", + "terminated", } if tensor_pixels_key: - exp_keys.add(tensor_pixels_key) - assert set(td.keys()) == exp_keys + expected_keys.add(tensor_pixels_key) + assert set(td.keys()) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "vec"), ("next", "pixels_orig"), @@ -6217,9 +6730,10 @@ def test_vip_parallel_reward(self, model, device, 
dtype_fixture): # noqa "action", ("next", "reward"), ("next", "done"), + ("next", "terminated"), } ) - assert set(td.keys(True)) == exp_keys, td + assert set(td.keys(True)) == expected_keys, td torch.manual_seed(1) tensordict_reset = TensorDict( @@ -6230,7 +6744,7 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa td = transformed_env.rollout( 5, auto_reset=False, tensordict=transformed_env.reset(tensordict_reset) ) - assert set(td.keys(True)) == exp_keys, td + assert set(td.keys(True)) == expected_keys, td # test that we do compute the reward we want cur_embedding = td["next", "vec"] goal_embedding = td["goal_embedding"] @@ -6245,8 +6759,8 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa with pytest.raises(AssertionError): torch.testing.assert_close(cur_embedding[:, 1:], last_embedding[:, :-1]) - explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - ( - -torch.norm(last_embedding - goal_embedding, dim=-1) + explicit_reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - ( + -torch.linalg.norm(last_embedding - goal_embedding, dim=-1) ) torch.testing.assert_close(explicit_reward, td["next", "reward"].squeeze()) @@ -6315,7 +6829,14 @@ def test_vip_spec_against_real(self, model, tensor_pixels_key, device): + ["action"] + list(transformed_env.observation_spec.keys()) + [("next", key) for key in transformed_env.observation_spec.keys()] - + [("next", "reward"), ("next", "done"), "done", "next"] + + [ + ("next", "reward"), + ("next", "done"), + "done", + ("next", "terminated"), + "terminated", + "next", + ] ) assert set(expected_keys) == set(transformed_env.rollout(3).keys(True)) @@ -6486,13 +7007,13 @@ def test_vc1_instantiation(self, del_keys, device): transformed_env = TransformedEnv(base_env, vc1) td = transformed_env.reset() assert td.device == device - exp_keys = {"nested", "done", "pixels_orig"} + expected_keys = {"nested", "done", "pixels_orig", "terminated"} if not del_keys: - exp_keys.add("pixels") - assert set(td.keys()) == exp_keys, set(td.keys()) - exp_keys + expected_keys.add("pixels") + assert set(td.keys()) == expected_keys, set(td.keys()) - expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "nested"), ("next", "nested", "vec"), @@ -6502,11 +7023,12 @@ def test_vc1_instantiation(self, del_keys, device): ("nested", "vec"), ("next", "reward"), ("next", "done"), + ("next", "terminated"), } ) if not del_keys: - exp_keys.add(("next", "pixels")) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + expected_keys.add(("next", "pixels")) + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() @pytest.mark.parametrize("del_keys", [True, False]) @@ -6525,13 +7047,13 @@ def test_transform_env(self, device, del_keys): td = transformed_env.reset() assert td.device == device assert td.batch_size == torch.Size([4]) - exp_keys = {"nested", "done", "pixels_orig"} + expected_keys = {"nested", "done", "pixels_orig", "terminated"} if not del_keys: - exp_keys.add("pixels") - assert set(td.keys()) == exp_keys + expected_keys.add("pixels") + assert set(td.keys()) == expected_keys td = transformed_env.rand_step(td) - exp_keys = exp_keys.union( + expected_keys = expected_keys.union( { ("next", "nested"), ("next", "nested", "vec"), @@ -6541,11 +7063,12 @@ def test_transform_env(self, device, del_keys): ("nested", "vec"), ("next", "reward"), ("next", "done"), + ("next", 
"terminated"), } ) if not del_keys: - exp_keys.add(("next", "pixels")) - assert set(td.keys(True)) == exp_keys, set(td.keys(True)) - exp_keys + expected_keys.add(("next", "pixels")) + assert set(td.keys(True)) == expected_keys, set(td.keys(True)) - expected_keys transformed_env.close() del transformed_env @@ -6569,7 +7092,14 @@ def test_vc1_spec_against_real(self, del_keys, device): unravel_key(("next", key)) for key in transformed_env.observation_spec.keys(True) ] - + [("next", "reward"), ("next", "done"), "done", "next"] + + [ + ("next", "reward"), + ("next", "done"), + "done", + ("next", "terminated"), + "terminated", + "next", + ] ) assert set(expected_keys) == set(transformed_env.rollout(3).keys(True)) @@ -6586,6 +7116,7 @@ def _test_vecnorm_subproc_auto( tensordict = env.reset() for _ in range(10): tensordict = env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put(True) msg = queue_in.get(timeout=TIMEOUT) assert msg == "all_done" @@ -6693,17 +7224,24 @@ def _run_parallelenv(parallel_env, queue_in, queue_out): assert msg == "start" for _ in range(10): tensordict = parallel_env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put("first round") msg = queue_in.get(timeout=TIMEOUT) assert msg == "start" for _ in range(10): tensordict = parallel_env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put("second round") parallel_env.close() queue_out.close() queue_in.close() del parallel_env, queue_out, queue_in + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Nested spawned multiprocessed is currently failing in python 3.11. " + "See https://github.com/python/cpython/pull/108568 for info and fix.", + ) def test_parallelenv_vecnorm(self): if _has_gym: make_env = EnvCreator( @@ -6777,6 +7315,7 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): for _ in range(N): td = env_t.rand_step(td) tds.append(td.clone()) + td = step_mdp(td) if td.get("done").any(): td = env_t.reset() tds = torch.stack(tds, 0) @@ -6790,6 +7329,15 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): env_t.close() self.SEED = 0 + def test_pickable(self): + + transform = VecNorm() + serialized = pickle.dumps(transform) + transform2 = pickle.loads(serialized) + assert transform.__dict__.keys() == transform2.__dict__.keys() + for key in sorted(transform.__dict__.keys()): + assert isinstance(transform.__dict__[key], type(transform2.__dict__[key])) + def test_added_transforms_are_in_eval_mode_trivial(): base_env = ContinuousActionVecMockEnv() @@ -6817,7 +7365,7 @@ def test_added_transforms_are_in_eval_mode(): class TestTransformedEnv: def test_independent_obs_specs_from_shared_env(self): obs_spec = CompositeSpec( - observation=BoundedTensorSpec(minimum=0, maximum=10, shape=torch.Size((1,))) + observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) ) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2)) @@ -6826,14 +7374,14 @@ def test_independent_obs_specs_from_shared_env(self): t1_obs_spec = t1.observation_spec t2_obs_spec = t2.observation_spec - assert t1_obs_spec["observation"].space.minimum == 3 - assert t1_obs_spec["observation"].space.maximum == 23 + assert t1_obs_spec["observation"].space.low == 3 + assert t1_obs_spec["observation"].space.high == 23 - assert t2_obs_spec["observation"].space.minimum == 1 - assert t2_obs_spec["observation"].space.maximum == 61 + assert t2_obs_spec["observation"].space.low == 1 + assert 
t2_obs_spec["observation"].space.high == 61 - assert base_env.observation_spec["observation"].space.minimum == 0 - assert base_env.observation_spec["observation"].space.maximum == 10 + assert base_env.observation_spec["observation"].space.low == 0 + assert base_env.observation_spec["observation"].space.high == 10 def test_independent_reward_specs_from_shared_env(self): reward_spec = UnboundedContinuousTensorSpec() @@ -6944,7 +7492,7 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): dim=-3, ) t2 = FiniteTensorDictCheck() - t3 = StepCounter() + t3 = ExcludeTransform() compose = Compose(t1, t2, t3) dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) td = TensorDict( @@ -6962,10 +7510,6 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): match="CatFrames cannot process unbatched tensordict instances", ): compose(td.clone(False)) - with pytest.raises( - NotImplementedError, match="StepCounter cannot be called independently" - ): - compose[1:](td.clone(False)) compose._call(td) for key in keys: assert td.get(key).shape[-3] == nchannels * N @@ -6987,20 +7531,8 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): ) @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize( - "keys_inv_1", - [ - ["action_1"], - [], - ], - ) - @pytest.mark.parametrize( - "keys_inv_2", - [ - ["action_2"], - [], - ], - ) + @pytest.mark.parametrize("keys_inv_1", [["action_1"], []]) + @pytest.mark.parametrize("keys_inv_2", [["action_2"], []]) def test_compose_inv(self, keys_inv_1, keys_inv_2, device): torch.manual_seed(0) keys_to_transform = set(keys_inv_1 + keys_inv_2) @@ -7101,11 +7633,11 @@ def test_insert(self): _ = env.reward_spec assert env._input_spec is not None - assert "_action_spec" in env._input_spec - assert env._input_spec["_action_spec"] is not None - assert env._output_spec["_observation_spec"] is not None - assert env._output_spec["_reward_spec"] is not None - assert env._output_spec["_done_spec"] is not None + assert "full_action_spec" in env._input_spec + assert env._input_spec["full_action_spec"] is not None + assert env._output_spec["full_observation_spec"] is not None + assert env._output_spec["full_reward_spec"] is not None + assert env._output_spec["full_done_spec"] is not None env.insert_transform(0, CatFrames(N=4, dim=-1, in_keys=[key])) @@ -7409,9 +7941,7 @@ def test_single_trans_env_check(self, create_copy, compose): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ) if compose: @@ -7424,12 +7954,8 @@ def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), RenameTransform( - [ - "observation", - ], - [ - "stuff", - ], + ["observation"], + ["stuff"], create_copy=create_copy, ), ) @@ -7444,9 +7970,7 @@ def make_env(): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ), ) @@ -7459,12 +7983,8 @@ def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), RenameTransform( - [ - "observation", - ], - [ - "stuff", - ], + ["observation"], + ["stuff"], create_copy=create_copy, ), ) @@ -7479,9 +7999,7 @@ def make_env(): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ), ) @@ -7496,12 +8014,8 @@ def make_env(): env = TransformedEnv( SerialEnv(2, make_env), RenameTransform( - [ - "observation", - ], - [ - "stuff", - ], + ["observation"], + ["stuff"], create_copy=create_copy, ), ) @@ -7512,9 +8026,7 @@ def 
make_env(): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ), ) @@ -7527,12 +8039,8 @@ def make_env(): env = TransformedEnv( ParallelEnv(2, make_env), RenameTransform( - [ - "observation", - ], - [ - "stuff", - ], + ["observation"], + ["stuff"], create_copy=create_copy, ), ) @@ -7543,9 +8051,7 @@ def make_env(): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ), ) @@ -7597,12 +8103,8 @@ def test_transform_env(self, create_copy): env = TransformedEnv( ContinuousActionVecMockEnv(), RenameTransform( - [ - "observation", - ], - [ - "stuff", - ], + ["observation"], + ["stuff"], create_copy=create_copy, ), ) @@ -7622,9 +8124,7 @@ def test_transform_env(self, create_copy): ["observation_orig"], ["stuff"], ["observation_orig"], - [ - "stuff", - ], + ["stuff"], create_copy=create_copy, ), ) @@ -7638,6 +8138,28 @@ def test_transform_env(self, create_copy): assert "stuff" in r.keys() assert ("next", "stuff") in r.keys(True) + def test_rename_done_reward(self, create_copy): + env = TransformedEnv( + ContinuousActionVecMockEnv(), + RenameTransform( + ["done"], + [("nested", "other_done")], + create_copy=create_copy, + ), + ) + assert ("nested", "other_done") in env.done_keys + check_env_specs(env) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + RenameTransform( + ["reward"], + [("nested", "reward")], + create_copy=create_copy, + ), + ) + assert ("nested", "reward") in env.reward_keys + check_env_specs(env) + def test_transform_model(self, create_copy): t = RenameTransform(["a"], ["b"], create_copy=create_copy) tensordict = TensorDict({"a": torch.randn(())}, []) @@ -7691,6 +8213,26 @@ def test_transform_inverse(self, create_copy): class TestInitTracker(TransformBase): + @pytest.mark.skipif(not _has_gym, reason="no gym detected") + def test_init_gym( + self, + ): + env = TransformedEnv( + GymEnv(PENDULUM_VERSIONED), + Compose(StepCounter(max_steps=30), InitTracker()), + ) + env.rollout(1000) + check_env_specs(env) + + @pytest.mark.skipif(not _has_dm_control, reason="no dm_control detected") + def test_init_dmc(self): + env = TransformedEnv( + DMControlEnv("cheetah", "run"), + Compose(StepCounter(max_steps=30), InitTracker()), + ) + env.rollout(1000) + check_env_specs(env) + def test_single_trans_env_check(self): env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2]) env = TransformedEnv(env, InitTracker()) @@ -7733,6 +8275,8 @@ def make_env(): check_env_specs(env) def test_transform_no_env(self): + with pytest.raises(ValueError, match="init_key can only be of type str"): + InitTracker(init_key=("some", "nested")) with pytest.raises( NotImplementedError, match="InitTracker cannot be executed without a parent" ): @@ -7783,7 +8327,7 @@ def test_transform_rb(self, rbclass): def test_transform_inverse(self): raise pytest.skip("No inverse for InitTracker") - @pytest.mark.parametrize("init_key", ["is_init", "loool", ("other", "key")]) + @pytest.mark.parametrize("init_key", ["is_init", "loool"]) @pytest.mark.parametrize("nested_done", [True, False]) @pytest.mark.parametrize("max_steps", [5]) def test_nested( @@ -7804,6 +8348,7 @@ def test_nested( env, InitTracker(init_key=init_key), ) + init_key = transformed_env.transform.init_keys[0] td = transformed_env.rollout( rollout_length, policy=policy, break_when_any_done=False ) @@ -7820,15 +8365,18 @@ def test_nested( assert torch.all(is_init[max_steps + 1] == 1) assert torch.all(is_init[max_steps + 2 :] == 0) 
- _reset = env.done_spec.rand() - td_reset = transformed_env.reset( - TensorDict( - {"_reset": _reset}, - batch_size=env.batch_size, - device=env.device, - ) + td_reset = TensorDict( + rand_reset(transformed_env), + batch_size=env.batch_size, + device=env.device, ) - assert (td_reset[init_key] == _reset).all() + if nested_done: + reset = td_reset["data", "_reset"] + else: + reset = td_reset["_reset"] + + td_reset = transformed_env.reset(td_reset) + assert (td_reset[init_key] == reset).all() class TestKLRewardTransform(TransformBase): @@ -7884,12 +8432,13 @@ def test_transform_no_env(self, in_key, out_key): { "action": torch.randn(*batch, 7), "observation": torch.randn(*batch, 7), - "next": {t.in_keys[0]: torch.zeros(*batch, 1)}, "sample_log_prob": torch.randn(*batch), }, batch, ) - t._step(tensordict) + next_td = TensorDict({t.in_keys[0]: torch.zeros(*batch, 1)}, batch) + next_td = t._step(tensordict, next_td) + tensordict.set("next", next_td) assert (tensordict.get("next").get(t.out_keys[0]) != 0).all() def test_transform_compose(self): @@ -8027,6 +8576,497 @@ def test_transform_rb(self, rbclass): def test_transform_inverse(self): raise pytest.skip("No inverse for KLRewardTransform") + @pytest.mark.parametrize("requires_grad", [True, False]) + def test_kl_diff(self, requires_grad): + actor = self._make_actor() + t = KLRewardTransform( + actor, in_keys="reward", out_keys="reward", requires_grad=requires_grad + ) + assert t.frozen_params.requires_grad is requires_grad + + def test_kl_lstm(self): + from tensordict.nn import ( + NormalParamExtractor, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictModule, + ) + + env = TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + lstm_module = LSTMModule( + input_size=env.observation_spec["observation"].shape[-1], + hidden_size=2, + in_keys=["observation", "rs_h", "rs_c"], + out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")], + ) + mlp = MLP(num_cells=[2], out_features=env.action_spec.shape[-1] * 2) + policy = ProbabilisticTensorDictSequential( + lstm_module, + TensorDictModule(mlp, in_keys=["intermediate"], out_keys=["intermediate"]), + TensorDictModule( + NormalParamExtractor(), + in_keys=["intermediate"], + out_keys=["loc", "scale"], + ), + ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + return_log_prob=True, + ), + ) + policy(env.reset()) + klt = KLRewardTransform(policy) + # check that this runs: it can only run if the params are nn.Parameter instances + klt(env.rollout(3, policy)) + + +class TestActionMask(TransformBase): + @property + def _env_class(self): + from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec + + class MaskedEnv(EnvBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.action_spec = DiscreteTensorSpec(4) + self.state_spec = CompositeSpec( + action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool) + ) + self.observation_spec = CompositeSpec( + obs=UnboundedContinuousTensorSpec(3), + action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool), + ) + self.reward_spec = UnboundedContinuousTensorSpec(1) + + def _reset(self, tensordict): + td = self.observation_spec.rand() + td.update(torch.ones_like(self.state_spec.rand())) + return td + + def _step(self, data): + td = self.observation_spec.rand() + mask = data.get("action_mask") + action = data.get("action") + mask = mask.scatter(-1, action.unsqueeze(-1), 0) + + td.set("action_mask", mask) + td.set("reward", 
self.reward_spec.rand()) + td.set("done", ~(mask.any().view(1))) + return td + + def _set_seed(self, seed): + return seed + + return MaskedEnv + + def test_single_trans_env_check(self): + env = self._env_class() + env = TransformedEnv(env, ActionMask()) + check_env_specs(env) + + def test_serial_trans_env_check(self): + env = SerialEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) + check_env_specs(env) + + def test_parallel_trans_env_check(self): + env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) + check_env_specs(env) + + def test_trans_serial_env_check(self): + env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask()) + check_env_specs(env) + + def test_trans_parallel_env_check(self): + env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) + check_env_specs(env) + + def test_transform_no_env(self): + t = ActionMask() + with pytest.raises(RuntimeError, match="parent cannot be None"): + t._call(TensorDict({}, [])) + + def test_transform_compose(self): + env = self._env_class() + env = TransformedEnv(env, Compose(ActionMask())) + check_env_specs(env) + + def test_transform_env(self): + env = TransformedEnv(ContinuousActionVecMockEnv(), ActionMask()) + with pytest.raises(ValueError, match="The action spec must be one of"): + env.rollout(2) + env = self._env_class() + env = TransformedEnv(env, ActionMask()) + td = env.reset() + for _ in range(1000): + td = env.rand_action(td) + assert env.action_spec.is_in(td.get("action")) + td = env.step(td) + td = step_mdp(td) + if td.get("done"): + break + else: + raise RuntimeError + assert not td.get("action_mask").any() + + def test_transform_model(self): + t = ActionMask() + with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))): + t(TensorDict({}, [])) + + def test_transform_rb(self): + t = ActionMask() + rb = ReplayBuffer(storage=LazyTensorStorage(100)) + rb.append_transform(t) + rb.extend(TensorDict({"a": [1]}, [1]).expand(10)) + with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))): + rb.sample(3) + + def test_transform_inverse(self): + # no inverse transform + return + + +class TestDeviceCastTransform(TransformBase): + def test_single_trans_env_check(self): + env = ContinuousActionVecMockEnv(device="cpu:0") + env = TransformedEnv(env, DeviceCastTransform("cpu:1")) + assert env.device == torch.device("cpu:1") + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1") + ) + + env = SerialEnv(2, make_env) + assert env.device == torch.device("cpu:1") + check_env_specs(env) + + def test_parallel_trans_env_check(self): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1") + ) + + env = ParallelEnv(2, make_env) + assert env.device == torch.device("cpu:1") + check_env_specs(env) + + def test_trans_serial_env_check(self): + def make_env(): + return ContinuousActionVecMockEnv(device="cpu:0") + + env = TransformedEnv(SerialEnv(2, make_env), DeviceCastTransform("cpu:1")) + assert env.device == torch.device("cpu:1") + check_env_specs(env) + + def test_trans_parallel_env_check(self): + def make_env(): + return ContinuousActionVecMockEnv(device="cpu:0") + + env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1")) + assert env.device == torch.device("cpu:1") + check_env_specs(env) + + def test_transform_no_env(self): + t = 
DeviceCastTransform("cpu:1", "cpu:0") + assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device( + "cpu:1" + ) + + def test_transform_compose(self): + t = Compose(DeviceCastTransform("cpu:1", "cpu:0")) + assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device( + "cpu:1" + ) + assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device( + "cpu:0" + ) + + def test_transform_env(self): + env = ContinuousActionVecMockEnv(device="cpu:0") + assert env.device == torch.device("cpu:0") + env = TransformedEnv(env, DeviceCastTransform("cpu:1")) + assert env.device == torch.device("cpu:1") + assert env.transform.device == torch.device("cpu:1") + assert env.transform.orig_device == torch.device("cpu:0") + + def test_transform_model(self): + t = Compose(DeviceCastTransform("cpu:1", "cpu:0")) + m = nn.Sequential(t) + assert t(TensorDict({}, [], device="cpu:0")).device == torch.device("cpu:1") + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + @pytest.mark.parametrize( + "storage", [TensorStorage, LazyTensorStorage, LazyMemmapStorage] + ) + def test_transform_rb(self, rbclass, storage): + t = Compose(DeviceCastTransform("cpu:1", "cpu:0")) + storage_kwargs = ( + { + "storage": TensorDict( + {"a": torch.zeros(20, 1, device="cpu:0")}, [20], device="cpu:0" + ) + } + if storage is TensorStorage + else {} + ) + rb = rbclass(storage=storage(max_size=20, device="auto", **storage_kwargs)) + rb.append_transform(t) + rb.add(TensorDict({"a": [1]}, [], device="cpu:1")) + assert rb._storage._storage.device == torch.device("cpu:0") + assert rb.sample(4).device == torch.device("cpu:1") + + def test_transform_inverse(self): + t = DeviceCastTransform("cpu:1", "cpu:0") + assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device( + "cpu:0" + ) + + +class TestPermuteTransform(TransformBase): + envclass = DiscreteActionConvMockEnv + + @classmethod + def _get_permute(cls): + return PermuteTransform( + (-1, -2, -3), in_keys=["pixels_orig", "pixels"], in_keys_inv=["pixels_orig"] + ) + + def test_single_trans_env_check(self): + base_env = TestPermuteTransform.envclass() + env = TransformedEnv(base_env, TestPermuteTransform._get_permute()) + check_env_specs(env) + assert env.observation_spec["pixels"] == env.observation_spec["pixels_orig"] + assert env.state_spec["pixels_orig"] == env.observation_spec["pixels_orig"] + + def test_serial_trans_env_check(self): + env = SerialEnv( + 2, + lambda: TransformedEnv( + TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() + ), + ) + check_env_specs(env) + + def test_parallel_trans_env_check(self): + env = ParallelEnv( + 2, + lambda: TransformedEnv( + TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() + ), + ) + check_env_specs(env) + + def test_trans_serial_env_check(self): + env = TransformedEnv( + SerialEnv(2, TestPermuteTransform.envclass), + TestPermuteTransform._get_permute(), + ) + check_env_specs(env) + + def test_trans_parallel_env_check(self): + env = TransformedEnv( + ParallelEnv(2, TestPermuteTransform.envclass), + TestPermuteTransform._get_permute(), + ) + check_env_specs(env) + + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) + def test_transform_compose(self, batch): + D, W, H, C = 8, 32, 64, 3 + trans = Compose( + PermuteTransform( + dims=(-1, -4, -2, -3), + in_keys=["pixels"], + ) + ) # DxWxHxC => CxDxHxW + td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch) + td = trans(td) + assert td["pixels"].shape == 
torch.Size((*batch, C, D, H, W)) + + def test_transform_env(self): + base_env = TestPermuteTransform.envclass() + env = TransformedEnv(base_env, TestPermuteTransform._get_permute()) + check_env_specs(env) + assert env.observation_spec["pixels"] == env.observation_spec["pixels_orig"] + assert env.state_spec["pixels_orig"] == env.observation_spec["pixels_orig"] + assert env.state_spec["pixels_orig"] != base_env.state_spec["pixels_orig"] + assert env.observation_spec["pixels"] != base_env.observation_spec["pixels"] + + td = env.rollout(3) + assert td["pixels"].shape == torch.Size([3, 7, 7, 1]) + + # check error + with pytest.raises(ValueError, match="Only tailing dims with negative"): + t = PermuteTransform((-1, -10)) + + def test_transform_model(self): + batch = [2] + D, W, H, C = 8, 32, 64, 3 + trans = PermuteTransform( + dims=(-1, -4, -2, -3), + in_keys=["pixels"], + ) # DxWxHxC => CxDxHxW + td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch) + out_channels = 4 + from tensordict.nn import TensorDictModule + + model = nn.Sequential( + trans, + TensorDictModule( + nn.Conv3d(C, out_channels, 3, padding=1), + in_keys=["pixels"], + out_keys=["pixels"], + ), + ) + td = model(td) + assert td["pixels"].shape == torch.Size((*batch, out_channels, D, H, W)) + + def test_transform_rb(self): + batch = [6] + D, W, H, C = 4, 5, 6, 3 + trans = PermuteTransform( + dims=(-1, -4, -2, -3), + in_keys=["pixels"], + ) # DxWxHxC => CxDxHxW + td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch) + rb = TensorDictReplayBuffer(storage=LazyTensorStorage(5), transform=trans) + rb.extend(td) + sample = rb.sample(2) + assert sample["pixels"].shape == torch.Size([2, C, D, H, W]) + + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) + def test_transform_inverse(self, batch): + D, W, H, C = 8, 32, 64, 3 + trans = PermuteTransform( + dims=(-1, -4, -2, -3), + in_keys_inv=["pixels"], + ) # DxWxHxC => CxDxHxW + td = TensorDict({"pixels": torch.randn((*batch, C, D, H, W))}, batch_size=batch) + td = trans.inv(td) + assert td["pixels"].shape == torch.Size((*batch, D, W, H, C)) + + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) + def test_transform_no_env(self, batch): + D, W, H, C = 8, 32, 64, 3 + trans = PermuteTransform( + dims=(-1, -4, -2, -3), + in_keys=["pixels"], + ) # DxWxHxC => CxDxHxW + td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch) + td = trans(td) + assert td["pixels"].shape == torch.Size((*batch, C, D, H, W)) + + +@pytest.mark.skipif( + not _has_gym, reason="EndOfLifeTransform can only be tested when Gym is present." 
+) +class TestEndOfLife(TransformBase): + def test_trans_parallel_env_check(self): + def make(): + with set_gym_backend("gymnasium"): + return GymEnv("ALE/Breakout-v5") + + with pytest.warns(UserWarning, match="The base_env is not a gym env"): + with pytest.raises(AttributeError): + env = TransformedEnv( + ParallelEnv(2, make), transform=EndOfLifeTransform() + ) + check_env_specs(env) + + def test_trans_serial_env_check(self): + def make(): + with set_gym_backend("gymnasium"): + return GymEnv("ALE/Breakout-v5") + + with pytest.warns(UserWarning, match="The base_env is not a gym env"): + env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform()) + check_env_specs(env) + + @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) + @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) + def test_single_trans_env_check(self, eol_key, lives_key): + with set_gym_backend("gymnasium"): + env = TransformedEnv( + GymEnv("ALE/Breakout-v5"), + transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), + ) + check_env_specs(env) + + @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) + @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) + def test_serial_trans_env_check(self, eol_key, lives_key): + def make(): + with set_gym_backend("gymnasium"): + return TransformedEnv( + GymEnv("ALE/Breakout-v5"), + transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), + ) + + env = SerialEnv(2, make) + check_env_specs(env) + + @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) + @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) + def test_parallel_trans_env_check(self, eol_key, lives_key): + def make(): + with set_gym_backend("gymnasium"): + return TransformedEnv( + GymEnv("ALE/Breakout-v5"), + transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), + ) + + env = ParallelEnv(2, make) + check_env_specs(env) + + def test_transform_no_env(self): + t = EndOfLifeTransform() + with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))): + t._step(TensorDict({}, []), TensorDict({}, [])) + + def test_transform_compose(self): + t = EndOfLifeTransform() + with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))): + Compose(t)._step(TensorDict({}, []), TensorDict({}, [])) + + @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) + @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) + def test_transform_env(self, eol_key, lives_key): + from tensordict.nn import TensorDictModule + from torchrl.objectives import DQNLoss + from torchrl.objectives.value import GAE + + with set_gym_backend("gymnasium"): + env = TransformedEnv( + GymEnv("ALE/Breakout-v5"), + transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), + ) + check_env_specs(env) + loss = DQNLoss(nn.Identity(), action_space="categorical") + env.transform.register_keys(loss) + assert ("next", eol_key) in loss.in_keys + gae = GAE( + gamma=0.9, + lmbda=0.9, + value_network=TensorDictModule(nn.Identity(), ["x"], ["y"]), + ) + env.transform.register_keys(gae) + assert ("next", eol_key) in gae.in_keys + + def test_transform_model(self): + t = EndOfLifeTransform() + with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))): + nn.Sequential(t)(TensorDict({}, [])) + + def test_transform_rb(self): + pass + + def test_transform_inverse(self): + pass + if __name__ == "__main__": args, unknown = 
argparse.ArgumentParser().parse_known_args() diff --git a/test/test_utils.py b/test/test_utils.py index 6a44226d780..99653ae4d36 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,6 +5,7 @@ import argparse import os import sys +from copy import copy from importlib import import_module from unittest import mock @@ -12,7 +13,8 @@ import pytest from torchrl._utils import get_binary_env_var, implement_for -from torchrl.envs.libs.gym import gym_backend, set_gym_backend + +from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @pytest.mark.parametrize("value", ["True", "1", "true"]) @@ -135,24 +137,27 @@ def test_implement_for(): def test_implement_for_missing_module(): - msg = r"Supported version of 'missing_module' has not been found." + msg = r"Supported version of 'test_utils.implement_for_test_functions.missing_module' has not been found." with pytest.raises(ModuleNotFoundError, match=msg): implement_for_test_functions.missing_module() def test_implement_for_missing_version(): - msg = r"Supported version of 'missing_version' has not been found." + msg = r"Supported version of 'test_utils.implement_for_test_functions.missing_version' has not been found." with pytest.raises(ModuleNotFoundError, match=msg): implement_for_test_functions.missing_version() def test_implement_for_reset(): assert implement_for_test_functions.select_correct_version() == "0.3+" - _impl = implement_for._implementations - assert _impl is implement_for._implementations - implement_for.reset() + _impl = copy(implement_for._implementations) + name = implement_for.func_name(implement_for_test_functions.select_correct_version) + for setter in implement_for._setters: + if implement_for.func_name(setter.fn) == name and setter.fn() != "0.3+": + setter.module_set() + assert implement_for_test_functions.select_correct_version() != "0.3+" + implement_for.reset(_impl) assert implement_for_test_functions.select_correct_version() == "0.3+" - assert _impl is not implement_for._implementations @pytest.mark.parametrize( @@ -234,13 +239,19 @@ def test_set_gym_environments( expected_fn_gymnasium = impfor.fn with set_gym_backend(gymnasium): - assert _utils_internal._set_gym_environments == expected_fn_gymnasium + assert ( + _utils_internal._set_gym_environments == expected_fn_gymnasium + ), expected_fn_gym with set_gym_backend(gym): - assert _utils_internal._set_gym_environments == expected_fn_gym + assert ( + _utils_internal._set_gym_environments == expected_fn_gym + ), expected_fn_gymnasium with set_gym_backend(gymnasium): - assert _utils_internal._set_gym_environments == expected_fn_gymnasium + assert ( + _utils_internal._set_gym_environments == expected_fn_gymnasium + ), expected_fn_gym def test_set_gym_environments_no_version_gymnasium_found(): @@ -253,10 +264,12 @@ def test_set_gym_environments_no_version_gymnasium_found(): import gymnasium + assert gymnasium.__version__ == "0.26.0" + # this version of gymnasium does not exist in implement_for # therefore, set_gym_backend will not set anything and raise an ImportError. msg = f"could not set anything related to gym backend {gymnasium_name} with version={gymnasium_version}." 
- with pytest.raises(ImportError, match=msg) as exc_info: + with pytest.raises(ImportError, match=msg): with set_gym_backend(gymnasium): _utils_internal._set_gym_environments() @@ -280,6 +293,69 @@ def test_set_gym_backend_types(): assert gym_backend() == gym +# we check that the order in which these functions are defined won't affect which is called +@implement_for("torch", "1.0", "1.8") +def torch_foo(): + return 0 + + +@implement_for("torch", "1.8", None) +def torch_foo(): # noqa: F811 + return 1 + + +@implement_for("torch", None, "1.0") +def torch_foo(): # noqa: F811 + return 1 + + +def test_set_gym_nested(): + mock_gym = uncallable(mock.MagicMock()) + mock_gym.__version__ = "0.21.0" + mock_gym.__name__ = "gym" + sys.modules["gym"] = mock_gym + + mock_gymnasium = uncallable(mock.MagicMock()) + mock_gymnasium.__version__ = "0.28.0" + mock_gymnasium.__name__ = "gymnasium" + sys.modules["gymnasium"] = mock_gymnasium + + import gym + import gymnasium + + assert torch_foo() == 1 + + class MockGym: + _is_batched = False + + with set_gym_backend(gym): + GymWrapper._output_transform( + MockGym, (1, 2, True, {}) + ) # would break with gymnasium + assert torch_foo() == 1 + with set_gym_backend(gymnasium): + GymWrapper._output_transform( + MockGym, (1, 2, True, True, {}) + ) # would break with gym + assert torch_foo() == 1 + GymWrapper._output_transform( + MockGym, (1, 2, True, {}) + ) # would break with gymnasium + with set_gym_backend("gym"): + GymWrapper._output_transform( + MockGym, (1, 2, True, {}) + ) # would break with gymnasium + assert torch_foo() == 1 + with set_gym_backend("gymnasium"): + GymWrapper._output_transform( + MockGym, (1, 2, True, True, {}) + ) # would break with gym + assert torch_foo() == 1 + GymWrapper._output_transform( + MockGym, (1, 2, True, {}) + ) # would break with gymnasium + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 3197b4b95d3..7d807244f70 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -5,8 +5,14 @@ import os from warnings import warn +import torch + from torch import multiprocessing as mp +if torch.cuda.device_count() > 1: + n = torch.cuda.device_count() - 1 + os.environ["MUJOCO_EGL_DEVICE_ID"] = str(1 + (os.getpid() % n)) + from ._extension import _init_extension @@ -35,3 +41,7 @@ import torchrl.modules import torchrl.objectives import torchrl.trainers + +# Filter warnings in subprocesses: True by default given the multiple optional +# deps of the library. This can be turned off via `torchrl.filter_warnings_subprocess = False`. +filter_warnings_subprocess = True diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 8d590b05210..de23df28425 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree.
+from __future__ import annotations import collections @@ -18,11 +19,13 @@ from distutils.util import strtobool from functools import wraps from importlib import import_module -from typing import Any, Callable, cast, TypeVar, Union +from typing import Any, Callable, cast, Dict, TypeVar, Union import numpy as np import torch from packaging.version import parse +from torch import multiprocessing as mp + VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) _os_is_windows = sys.platform == "win32" @@ -248,6 +251,7 @@ class implement_for: # Stores pointers to fitting implementations: dict[func_name] = func_pointer _implementations = {} _setters = [] + _cache_modules = {} def __init__( self, @@ -269,27 +273,58 @@ def check_version(version, from_version, to_version): @staticmethod def get_class_that_defined_method(f): """Returns the class of a method, if it is defined, and None otherwise.""" - return f.__globals__.get(f.__qualname__.split(".")[0], None) + out = f.__globals__.get(f.__qualname__.split(".")[0], None) + return out - @property - def func_name(self): - return self.fn.__name__ + @classmethod + def func_name(cls, fn): + # produces a name like torchrl.module.Class.method or torchrl.module.function + first = str(fn).split(".")[0][len(" str: + @classmethod + def import_module(cls, module_name: Union[Callable, str]) -> str: """Imports module and returns its version.""" if not callable(module_name): - module = import_module(module_name) + module = cls._cache_modules.get(module_name, None) + if module is None: + if module_name in sys.modules: + sys.modules[module_name] = module = import_module(module_name) + else: + cls._cache_modules[module_name] = module = import_module( + module_name + ) else: module = module_name() return module.__version__ @@ -298,7 +333,7 @@ def __call__(self, fn): self.fn = fn # If the module is missing replace the function with the mock. - func_name = self.func_name + func_name = self.func_name(self.fn) implementations = implement_for._implementations @wraps(fn) @@ -307,7 +342,7 @@ def unsupported(*args, **kwargs): f"Supported version of '{func_name}' has not been found." ) - do_set = False + self.do_set = False # Return fitting implementation if it was encountered before. if func_name in implementations: try: @@ -320,36 +355,45 @@ def unsupported(*args, **kwargs): f"Got multiple backends for {func_name}. " f"Using the last queried ({module} with version {version})." ) - do_set = True - if not do_set: - return implementations[func_name] + self.do_set = True + if not self.do_set: + return implementations[func_name].fn except ModuleNotFoundError: # then it's ok, there is no conflict - return implementations[func_name] + return implementations[func_name].fn else: try: version = self.import_module(self.module_name) if self.check_version(version, self.from_version, self.to_version): - do_set = True + self.do_set = True except ModuleNotFoundError: return unsupported - if do_set: - implementations[func_name] = fn + if self.do_set: self.module_set() return fn return unsupported @classmethod - def reset(cls, setters=None): + def reset(cls, setters_dict: Dict[str, implement_for] = None): + """Resets the setters in setter_dict. + + ``setter_dict`` is a copy of implementations. We just need to iterate through its + values and call :meth:`~.module_set` for each. 
+ + """ if VERBOSE: print("resetting implement_for") - if setters is None: - setters = copy(cls._setters) - cls._setters = [] - cls._implementations = {} - for setter in setters: - setter(setter.fn) - cls._setters.append(setter) + if setters_dict is None: + setters_dict = copy(cls._implementations) + for setter in setters_dict.values(): + setter.module_set() + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"module_name={self.module_name}({self.from_version, self.to_version}), " + f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})" + ) def accept_remote_rref_invocation(func): @@ -529,3 +573,26 @@ def clone(self): def get_trace(): """A simple debugging util to spot where a function is being called.""" traceback.print_stack() + + +class _ProcessNoWarn(mp.Process): + """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess.""" + + @wraps(mp.Process.__init__) + def __init__(self, *args, num_threads=None, **kwargs): + import torchrl + + self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess + self.num_threads = num_threads + super().__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + if self.num_threads is not None: + torch.set_num_threads(self.num_threads) + if self.filter_warnings_subprocess: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return mp.Process.run(self, *args, **kwargs) + return mp.Process.run(self, *args, **kwargs) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 74d2271851d..afd8ae61765 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import _pickle import abc import inspect @@ -30,23 +32,25 @@ from torchrl._utils import ( _check_for_faulty_process, + _ProcessNoWarn, accept_remote_rref_udf_invocation, prod, RL_WARNINGS, VERBOSE, ) from torchrl.collectors.utils import split_trajectories -from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( + _aggregate_resets, _convert_exploration_type, ExplorationType, set_exploration_type, step_mdp, + terminated_or_truncated, ) -from torchrl.envs.vec_env import _BatchedEnv _TIMEOUT = 1.0 _MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory @@ -69,8 +73,8 @@ class RandomPolicy: >>> from tensordict import TensorDict >>> from torchrl.data.tensor_specs import BoundedTensorSpec >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) - >>> actor = RandomPolicy(spec=action_spec) - >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1] + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] """ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): @@ -78,7 +82,10 @@ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): self.action_key = action_key def __call__(self, td: TensorDictBase) -> TensorDictBase: - return td.set(self.action_key, self.action_spec.rand()) + if isinstance(self.action_spec, CompositeSpec): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) class _Interruptor: @@ -210,7 +217,7 @@ def _get_policy_and_device( raise ValueError( "env must be provided to _get_policy_and_device if policy is None" ) - policy = RandomPolicy(self.env.action_spec, self.env.action_key) + policy = RandomPolicy(self.env.input_spec["full_action_spec"]) elif isinstance(policy, nn.Module): # TODO: revisit these checks when we have determined whether arbitrary # callables should be supported as policies. @@ -245,7 +252,7 @@ def _get_policy_and_device( if not hasattr(self, "env") or self.env is None: out_keys = ["action"] else: - out_keys = [self.env.action_key] + out_keys = self.env.action_keys output = policy(**next_observation) if isinstance(output, tuple): @@ -386,12 +393,13 @@ class SyncDataCollector(DataCollectorBase): If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e. no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. 
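# A minimal sketch of the ceiling rounding described above for ``init_random_frames``
# (the concrete values below are illustrative, not taken from the patch):
init_random_frames = 50
frames_per_batch = 64
# requesting 50 random frames with batches of 64 effectively yields one full batch of 64
effective_init_random_frames = -(-init_random_frames // frames_per_batch) * frames_per_batch
assert effective_init_random_frames == 64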
@@ -419,9 +427,6 @@ class SyncDataCollector(DataCollectorBase): The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement strategies such as preeptively stopping rollout collection. Default is ``False``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -446,25 +451,28 @@ class SyncDataCollector(DataCollectorBase): ... break TensorDict( fields={ - action: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ - step_count: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), - "traj_ids: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([4, 50]), + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), - mask: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False) >>> del collector @@ -493,18 +501,20 @@ def __init__( total_frames: int, device: DEVICE_TYPING = None, storing_device: DEVICE_TYPING = None, - create_env_kwargs: Optional[dict] = None, - max_frames_per_traj: int = -1, - init_random_frames: int = -1, + create_env_kwargs: dict | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, - postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, - split_trajs: Optional[bool] = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, exploration_mode=None, return_same_td: bool = False, reset_when_done: bool = True, 
interruptor=None, ): + from torchrl.envs.batched_envs import _BatchedEnv + self.closed = True exploration_type = _convert_exploration_type( @@ -539,6 +549,8 @@ def __init__( self.storing_device = torch.device(storing_device) self.env: EnvBase = env self.closed = False + if not reset_when_done: + raise ValueError("reset_when_done is deprecated.") self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() @@ -558,14 +570,14 @@ def __init__( self.env: EnvBase = self.env.to(self.device) self.max_frames_per_traj = max_frames_per_traj - if self.max_frames_per_traj > 0: + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: # let's check that there is no StepCounter yet for key in self.env.output_spec.keys(True, True): if isinstance(key, str): key = (key,) - if "truncated" in key: + if "step_count" in key: raise ValueError( - "A 'truncated' key is already present in the environment " + "A 'step_count' key is already present in the environment " "and the 'max_frames_per_traj' argument may conflict with " "a 'StepCounter' that has already been set. " "Possible solutions: Set max_frames_per_traj to 0 or " @@ -588,13 +600,26 @@ def __init__( self.total_frames = total_frames self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = init_random_frames + if ( + init_random_frames is not None + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.postproc = postproc if self.postproc is not None and hasattr(self.postproc, "to"): self.postproc.to(self.storing_device) if frames_per_batch % self.n_env != 0 and RL_WARNINGS: warnings.warn( - f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " - f" this results in more frames_per_batch per iteration that requested." + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration than requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env})." "To silence this message, set the environment variable RL_WARNINGS to False."
) self.requested_frames_per_batch = frames_per_batch @@ -612,64 +637,45 @@ def __init__( ) with torch.no_grad(): - self._tensordict_out = env.fake_tensordict() + self._tensordict_out = self.env.fake_tensordict() + # If the policy has a valid spec, we use it if ( - hasattr(self.policy, "spec") - and self.policy.spec is not None - and all( - v is not None for v in self.policy.spec.values(True, True) - ) # if a spec is None, we don't know anything about it - # and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys) - and any( - key not in self._tensordict_out.keys(isinstance(key, tuple)) - for key in self.policy.spec.keys(True, True) - ) - ): - # if policy spec is non-empty, all the values are not None and the keys - # match the out_keys we assume the user has given all relevant information - # the policy could have more keys than the env: - policy_spec = self.policy.spec - if policy_spec.ndim < self._tensordict_out.ndim: - policy_spec = policy_spec.expand(self._tensordict_out.shape) - for key, spec in policy_spec.items(True, True): - if key in self._tensordict_out.keys(isinstance(key, tuple)): - continue - self._tensordict_out.set(key, spec.zero()) - self._tensordict_out = ( - self._tensordict_out.unsqueeze(-1) - .expand(*env.batch_size, self.frames_per_batch) - .clone() - ) - elif ( hasattr(self.policy, "spec") and self.policy.spec is not None and all(v is not None for v in self.policy.spec.values(True, True)) - and all( - key in self._tensordict_out.keys(isinstance(key, tuple)) - for key in self.policy.spec.keys(True, True) - ) ): - # reach this if the policy has specs and they match with the fake tensordict - self._tensordict_out = ( - self._tensordict_out.unsqueeze(-1) - .expand(*env.batch_size, self.frames_per_batch) - .clone() - ) + if any( + key not in self._tensordict_out.keys(isinstance(key, tuple)) + for key in self.policy.spec.keys(True, True) + ): + # if policy spec is non-empty, all the values are not None and the keys + # match the out_keys we assume the user has given all relevant information + # the policy could have more keys than the env: + policy_spec = self.policy.spec + if policy_spec.ndim < self._tensordict_out.ndim: + policy_spec = policy_spec.expand(self._tensordict_out.shape) + for key, spec in policy_spec.items(True, True): + if key in self._tensordict_out.keys(isinstance(key, tuple)): + continue + self._tensordict_out.set(key, spec.zero()) + else: # otherwise, we perform a small number of steps with the policy to # determine the relevant keys with which to pre-populate _tensordict_out. # This is the safest thing to do if the spec has None fields or if there is # no spec at all. # See #505 for additional context. 
+ self._tensordict_out.update(self._tensordict) with torch.no_grad(): - self._tensordict_out = self._tensordict_out.to(self.device) - self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1) - self._tensordict_out = ( - self._tensordict_out.expand(*env.batch_size, self.frames_per_batch) - .clone() - .zero_() - ) - # in addition to outputs of the policy, we add traj_ids and step_count to + self._tensordict_out = self.policy(self._tensordict_out.to(self.device)) + + self._tensordict_out = ( + self._tensordict_out.unsqueeze(-1) + .expand(*env.batch_size, self.frames_per_batch) + .clone() + .zero_() + ) + # in addition to outputs of the policy, we add traj_ids to # _tensordict_out which will be collected during rollout self._tensordict_out = self._tensordict_out.to(self.storing_device) self._tensordict_out.set( @@ -684,13 +690,11 @@ def __init__( if split_trajs is None: split_trajs = False - elif not self.reset_when_done and split_trajs: - raise RuntimeError( - "Cannot split trajectories when reset_when_done is False." - ) self.split_trajs = split_trajs self._exclude_private_keys = True self.interruptor = interruptor + self._frames = 0 + self._iter = -1 # for RPC def next(self): @@ -717,9 +721,11 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: Examples: >>> from torchrl.envs import ParallelEnv >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule >>> env_fn = lambda: GymEnv("Pendulum-v1") >>> env_fn_parallel = ParallelEnv(6, env_fn) - >>> collector = SyncDataCollector(env_fn_parallel) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) >>> out_seed = collector.set_seed(1) # out_seed = 6 """ @@ -739,11 +745,9 @@ def iterator(self) -> Iterator[TensorDictBase]: stream = None with torch.cuda.stream(stream): total_frames = self.total_frames - i = -1 - self._frames = 0 - while True: - i += 1 - self._iter = i + + while self._frames < self.total_frames: + self._iter += 1 tensordict_out = self.rollout() self._frames += tensordict_out.numel() if self._frames >= total_frames: @@ -782,56 +786,40 @@ def iterator(self) -> Iterator[TensorDictBase]: # >>> assert data0["done"] is not data1["done"] yield tensordict_out.clone() - if self._frames >= self.total_frames: - break - def _step_and_maybe_reset(self) -> None: - done = self._tensordict.get(("next", self.env.done_key)) - truncated = self._tensordict.get(("next", "truncated"), None) self._tensordict = step_mdp( self._tensordict, - reward_key=self.env.reward_key, - done_key=self.env.done_key, - action_key=self.env.action_key, + reward_keys=self.env.reward_keys, + done_keys=self.env.done_keys, + action_keys=self.env.action_keys, ) - if not self.reset_when_done: return - - done_or_terminated = ( - (done | truncated) if truncated is not None else done.clone() + td_reset = self._tensordict.clone(False) + any_done = terminated_or_truncated( + td_reset, + full_done_spec=self.env.output_spec["full_done_spec"], + key="_reset", ) - if done_or_terminated.any(): + + if any_done: traj_ids = self._tensordict.get(("collector", "traj_ids")) traj_ids = traj_ids.clone() # collectors do not support passing other tensors than `"_reset"` # to `reset()`. 
- _reset = done_or_terminated - td_reset = self._tensordict.select().set("_reset", _reset) + traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys) td_reset = self.env.reset(td_reset) - td_reset.del_("_reset") - traj_done_or_terminated = done_or_terminated.sum( - tuple(range(self._tensordict.batch_dims, done_or_terminated.ndim)), - dtype=torch.bool, - ) + if td_reset.batch_dims: # better cloning here than when passing the td for stacking - # cloning is necessary to avoid modifying dones in-place - self._tensordict = self._tensordict.clone() - self._tensordict.get_sub_tensordict(traj_done_or_terminated).update( - td_reset[traj_done_or_terminated] - ) + # cloning is necessary to avoid modifying entries in-place + self._tensordict = torch.where(traj_sop, td_reset, self._tensordict) else: self._tensordict.update(td_reset) - done = self._tensordict.get(self.env.done_key) - if done.any(): - raise RuntimeError( - f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." - ) - traj_ids[traj_done_or_terminated] = traj_ids.max() + torch.arange( - 1, traj_done_or_terminated.sum() + 1, device=traj_ids.device + traj_ids[traj_sop] = traj_ids.max() + torch.arange( + 1, traj_sop.sum() + 1, device=traj_ids.device ) self._tensordict.set(("collector", "traj_ids"), traj_ids) @@ -851,12 +839,14 @@ def rollout(self) -> TensorDictBase: tensordicts = [] with set_exploration_type(self.exploration_type): for t in range(self.frames_per_batch): - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): self.env.rand_step(self._tensordict) else: self.policy(self._tensordict) self.env.step(self._tensordict) - # we must clone all the values, since the step / traj_id updates are done in-place tensordicts.append(self._tensordict.to(self.storing_device)) @@ -898,7 +888,7 @@ def rollout(self) -> TensorDictBase: def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" # metadata - md = self._tensordict["collector"].clone() + md = self._tensordict.get("collector").clone() if index is not None: # check that the env supports partial reset if prod(self.env.batch_size) == 0: @@ -910,7 +900,7 @@ def reset(self, index=None, **kwargs) -> None: ) _reset[index] = 1 self._tensordict[index].zero_() - self._tensordict["_reset"] = _reset + self._tensordict.set("_reset", _reset) else: _reset = None self._tensordict.zero_() @@ -947,6 +937,8 @@ def state_dict(self) -> OrderedDict: `"env_state_dict"`. 
""" + from torchrl.envs.batched_envs import _BatchedEnv + if isinstance(self.env, TransformedEnv): env_state_dict = self.env.transform.state_dict() elif isinstance(self.env, _BatchedEnv): @@ -963,6 +955,8 @@ def state_dict(self) -> OrderedDict: else: state_dict = OrderedDict(env_state_dict=env_state_dict) + state_dict.update({"frames": self._frames, "iter": self._iter}) + return state_dict def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: @@ -978,6 +972,8 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) if strict or "policy_state_dict" in state_dict: self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs) + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] def __repr__(self) -> str: env_str = indent(f"env={self.env}", 4 * " ") @@ -1039,12 +1035,13 @@ class _MultiDataCollector(DataCollectorBase): If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e. no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -1075,6 +1072,14 @@ class _MultiDataCollector(DataCollectorBase): Defaults to ``False``. preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finished collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may charge the cpu load too much and harm performance. 
+ """ def __init__( @@ -1092,8 +1097,8 @@ def __init__( device: DEVICE_TYPING = None, storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, create_env_kwargs: Optional[Sequence[dict]] = None, - max_frames_per_traj: int = -1, - init_random_frames: int = -1, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, @@ -1104,11 +1109,17 @@ def __init__( update_at_each_batch: bool = False, devices=None, storing_devices=None, + num_threads: int = None, + num_sub_threads: int = 1, ): exploration_type = _convert_exploration_type( exploration_mode=exploration_mode, exploration_type=exploration_type ) self.closed = True + if num_threads is None: + num_threads = len(create_env_fn) + 1 # 1 more thread for this proc + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads self.create_env_fn = create_env_fn self.num_workers = len(create_env_fn) self.create_env_kwargs = ( @@ -1262,6 +1273,8 @@ def device_err_msg(device_name, devices_list): self.interruptor = None self._run_processes() self._exclude_private_keys = True + self._frames = 0 + self._iter = -1 @property def frames_per_batch_worker(self): @@ -1283,6 +1296,7 @@ def _queue_len(self) -> int: raise NotImplementedError def _run_processes(self) -> None: + torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] self.pipes = [] @@ -1314,7 +1328,11 @@ def _run_processes(self) -> None: "idx": i, "interruptor": self.interruptor, } - proc = mp.Process(target=_main_async_collector, kwargs=kwargs) + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) # proc.daemon can't be set as daemonic processes may be launched by the process itself try: proc.start() @@ -1332,6 +1350,7 @@ def _run_processes(self) -> None: pipe_child.close() self.procs.append(proc) self.pipes.append(pipe_parent) + for pipe_parent in self.pipes: msg = pipe_parent.recv() if msg != "instantiated": raise RuntimeError(msg) @@ -1373,7 +1392,9 @@ def _shutdown_main(self) -> None: continue for proc in self.procs: - proc.join(10.0) + exitcode = proc.join(1.0) + if exitcode is None: + proc.terminate() self.queue_out.close() for pipe in self.pipes: pipe.close() @@ -1393,9 +1414,13 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: environment. 
Examples: - >>> env_fn = lambda: GymEnv("Pendulum-v0") + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> env_fn = lambda: GymEnv("Pendulum-v1") >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) - >>> collector = SyncDataCollector(env_fn_parallel) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) >>> out_seed = collector.set_seed(1) # out_seed = 6 """ @@ -1444,6 +1469,7 @@ def state_dict(self) -> OrderedDict: if msg != "state_dict": raise RuntimeError(f"Expected msg='state_dict', got {msg}") state_dict[f"worker{idx}"] = _state_dict + state_dict.update({"frames": self._frames, "iter": self._iter}) return state_dict @@ -1461,6 +1487,8 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: _, msg = self.pipes[idx].recv() if msg != "loaded": raise RuntimeError(f"Expected msg='loaded', got {msg}") + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] @accept_remote_rref_udf_invocation @@ -1533,25 +1561,28 @@ class MultiSyncDataCollector(_MultiDataCollector): ... break TensorDict( fields={ - action: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ - step_count: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), - traj_ids: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([4, 50]), + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), - mask: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False) >>> collector.shutdown() @@ -1609,27 +1640,29 @@ def _queue_len(self) -> int: return 
self.num_workers def iterator(self) -> Iterator[TensorDictBase]: - i = -1 - frames = 0 + self.buffers = {} dones = [False for _ in range(self.num_workers)] workers_frames = [0 for _ in range(self.num_workers)] same_device = None self.out_buffer = None - while not all(dones) and frames < self.total_frames: + while not all(dones) and self._frames < self.total_frames: _check_for_faulty_process(self.procs) if self.update_at_each_batch: self.update_policy_weights_() for idx in range(self.num_workers): - if frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): msg = "continue_random" else: msg = "continue" self.pipes[idx].send((None, msg)) - i += 1 + self._iter += 1 max_traj_idx = None if self.interruptor is not None and self.preemptive_threshold < 1.0: @@ -1684,10 +1717,10 @@ def iterator(self) -> Iterator[TensorDictBase]: if self.split_trajs: out = split_trajectories(self.out_buffer, prefix="collector") - frames += out.get(("collector", "mask")).sum().item() + self._frames += out.get(("collector", "mask")).sum().item() else: out = self.out_buffer.clone() - frames += prod(out.shape) + self._frames += prod(out.shape) if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) @@ -1764,25 +1797,28 @@ class MultiaSyncDataCollector(_MultiDataCollector): ... break TensorDict( fields={ - action: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ - step_count: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), - traj_ids: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([4, 50]), + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), - mask: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([4, 50]), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + 
batch_size=torch.Size([200]), device=cpu, is_shared=False) >>> collector.shutdown() @@ -1856,18 +1892,16 @@ def iterator(self) -> Iterator[TensorDictBase]: self.update_policy_weights_() for i in range(self.num_workers): - if self.init_random_frames > 0: + if self.init_random_frames is not None and self.init_random_frames > 0: self.pipes[i].send((None, "continue_random")) else: self.pipes[i].send((None, "continue")) self.running = True - i = -1 - self._frames = 0 workers_frames = [0 for _ in range(self.num_workers)] while self._frames < self.total_frames: _check_for_faulty_process(self.procs) - i += 1 + self._iter += 1 idx, j, out = self._get_from_queue() worker_frames = out.numel() @@ -1880,7 +1914,10 @@ def iterator(self) -> Iterator[TensorDictBase]: # the function blocks here until the next item is asked, hence we send the message to the # worker to keep on working in the meantime before the yield statement - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): msg = "continue_random" else: msg = "continue" @@ -1907,7 +1944,10 @@ def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None: raise Exception("self.queue_out is full") if self.running: for idx in range(self.num_workers): - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): self.pipes[idx].send((idx, "continue_random")) else: self.pipes[idx].send((idx, "continue")) @@ -1941,14 +1981,14 @@ class aSyncDataCollector(MultiaSyncDataCollector): environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Default is -1 (i.e. no maximum number of steps) + Defaults to ``None`` (i.e. no maximum number of steps) frames_per_batch (int): Time-length of a batch. reset_at_each_iter and frames_per_batch == n_steps are equivalent configurations. - default: 200 + Defaults to ``200`` init_random_frames (int): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - default=-1 (i.e. no random frames) + Defaults to ``None`` (i.e. no random frames) reset_at_each_iter (bool): Whether or not environments should be reset for each batch. default=False. 
postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a @@ -1983,9 +2023,9 @@ def __init__( ] = None, total_frames: Optional[int] = -1, create_env_kwargs: Optional[dict] = None, - max_frames_per_traj: int = -1, + max_frames_per_traj: int | None = None, frames_per_batch: int = 200, - init_random_frames: int = -1, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index eb0459698d5..752a09231c0 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -14,9 +14,9 @@ import torch.cuda from tensordict import TensorDict -from torch import multiprocessing as mp, nn +from torch import nn -from torchrl._utils import VERBOSE +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -31,7 +31,8 @@ ) from torchrl.collectors.utils import split_trajectories from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.common import EnvBase +from torchrl.envs.env_creator import EnvCreator from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None @@ -610,7 +611,7 @@ def _init_worker_dist_mp(self, i): if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) TCP_PORT = self.tcp_port - job = mp.Process( + job = _ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 9211077c2e9..c05da8c5a0f 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -14,7 +14,8 @@ SyncDataCollector, ) from torchrl.collectors.utils import split_trajectories -from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.common import EnvBase +from torchrl.envs.env_creator import EnvCreator logger = logging.getLogger(__name__) diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index fc0a40c9ece..5fef2dd1666 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -32,10 +32,10 @@ SUBMITIT_ERR = err import torch.cuda from tensordict import TensorDict -from torch import multiprocessing as mp, nn +from torch import nn from torch.distributed import rpc -from torchrl._utils import VERBOSE +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -44,7 +44,8 @@ MultiSyncDataCollector, SyncDataCollector, ) -from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.common import EnvBase +from torchrl.envs.env_creator import EnvCreator def _rpc_init_collection_node( @@ -446,7 +447,7 @@ def _init_worker_rpc(self, executor, i): print("job id", job.job_id) # ID of your job return job elif self.launcher == "mp": - job = mp.Process( + job = _ProcessNoWarn( target=_rpc_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index d4662105444..66e55318832 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -13,8 +13,8 @@ import torch.cuda from 
tensordict import TensorDict -from torch import multiprocessing as mp, nn -from torchrl._utils import VERBOSE +from torch import nn +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -29,7 +29,8 @@ ) from torchrl.collectors.utils import split_trajectories from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.common import EnvBase +from torchrl.envs.env_creator import EnvCreator from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None @@ -396,7 +397,7 @@ def _init_worker_dist_mp(self, i): env_make = self.env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - job = mp.Process( + job = _ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index ccc7b89809c..87145f26847 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -58,15 +58,21 @@ def split_trajectories( traj_ids = rollout_tensordict.get(traj_ids_key, None) done = rollout_tensordict.get(("next", "done")) - truncated = rollout_tensordict.get( - ("next", "truncated"), torch.zeros((), device=done.device, dtype=torch.bool) - ) - done = done | truncated if traj_ids is None: - traj_ids = done.cumsum(rollout_tensordict.ndim - 1) + idx = (slice(None),) * (rollout_tensordict.ndim - 1) + (slice(None, -1),) + done_sel = done[idx] + pads = [1, 0] + pads = [0, 0] * (done.ndim - rollout_tensordict.ndim) + pads + done_sel = torch.nn.functional.pad(done_sel, pads) + if done_sel.shape != done.shape: + raise RuntimeError( + f"done and done_sel have different shape {done.shape} - {done_sel.shape} " + ) + traj_ids = done_sel.cumsum(rollout_tensordict.ndim - 1) + traj_ids = traj_ids.squeeze(-1) if rollout_tensordict.ndim > 1: for i in range(1, rollout_tensordict.shape[0]): - traj_ids[i] += traj_ids[i - 1].max() + traj_ids[i] += traj_ids[i - 1].max() + 1 rollout_tensordict.set(traj_ids_key, traj_ids) splits = traj_ids.view(-1) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index b3ad5caf02d..4c90146ac7f 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -46,3 +46,4 @@ UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) +from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index d6c32083e23..9516a6e8102 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -2,15 +2,22 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations -from typing import Callable, Optional +import os +import urllib +import warnings +from typing import Callable import numpy as np import torch + +from tensordict import PersistentTensorDict from tensordict.tensordict import make_tensordict from torchrl.collectors.utils import split_trajectories +from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS from torchrl.data.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -28,7 +35,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): If present, metadata will be written in ``D4RLExperienceReplay.metadata`` and excluded from the dataset. - The transitions are reconstructed using ``done = terminal | timeout`` and + The transitions are reconstructed using ``done = terminated | truncated`` and the ``("next", "observation")`` of ``"done"`` states are zeroed. Args: @@ -50,8 +57,8 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which - is recovered via ``done = timeout | terminal``. In other words, - it is assumed that any ``timeout`` or ``terminal`` signal is + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is equivalent to the end of a trajectory. For some datasets from ``D4RL``, this may not be true. It is up to the user to make accurate choices regarding this usage of ``split_trajs``. @@ -72,21 +79,28 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): .. note:: The keys in ``from_env=True`` and ``from_env=False`` *may* unexpectedly - differ. In particular, the ``"timeout"`` key (used to determine the + differ. In particular, the ``"truncated"`` key (used to determine the end of an episode) may be absent when ``from_env=False`` but present otherwise, leading to a different slicing when ``split_trajs`` is enabled. - - use_timeout_as_done (bool, optional): if ``True``, ``done = terminal | timeout``. - Otherwise, only the ``terminal`` key is used. Defaults to ``True``. + direct_download (bool): if ``True``, the data will be downloaded without + requiring D4RL. If ``None``, the dataset will be downloaded via ``d4rl`` + when it is installed in the environment, and the download will otherwise + fall back on ``direct_download=True``. + This is not compatible with ``from_env=True``. + Defaults to ``None``. + use_truncated_as_done (bool, optional): if ``True``, ``done = terminated | truncated``. + Otherwise, only the ``terminated`` key is used. Defaults to ``True``. + terminate_on_end (bool, optional): Set ``done=True`` on the last timestep + in a trajectory. Defaults to ``False``, in which case the last timestep + of each trajectory is discarded. **env_kwargs (key-value pairs): additional kwargs for - :func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end`` - (``False`` by default) or other kwargs if defined by D4RL library. + :func:`d4rl.qlearning_dataset`.
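As an illustrative sketch of the new download path (not part of the patch): with ``from_env=False`` and ``direct_download=True``, the HDF5 file is fetched from the URL registry added in ``d4rl_infos.py`` below and cached locally, so the ``d4rl`` package itself is not needed:
    >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
    >>> # the file is resolved through D4RL_DATASETS and cached under
    >>> # $D4RL_DATASET_DIR (default: ~/.cache/torchrl/data/d4rl/datasets)
    >>> data = D4RLExperienceReplay(
    ...     "maze2d-umaze-v1",
    ...     batch_size=128,
    ...     from_env=False,
    ...     direct_download=True,
    ... )
    >>> batch = data.sample()  # 128 transitions, including ("next", ...) entries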
Examples: >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay >>> from torchrl.envs import ObservationNorm - >>> data = D4RLExperienceReplay("maze2d-umaze-v1") + >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128) >>> # we can append transforms to the dataset >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0)) >>> data.sample(128) @@ -109,33 +123,70 @@ def __init__( self, name, batch_size: int, - sampler: Optional[Sampler] = None, - writer: Optional[Writer] = None, - collate_fn: Optional[Callable] = None, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, pin_memory: bool = False, - prefetch: Optional[int] = None, - transform: Optional["Transform"] = None, # noqa-F821 + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 split_trajs: bool = False, - from_env: bool = True, - use_timeout_as_done: bool = True, + from_env: bool = None, + use_truncated_as_done: bool = True, + direct_download: bool = None, + terminate_on_end: bool = None, **env_kwargs, ): - - type(self)._import_d4rl() - - if not self._has_d4rl: - raise ImportError("Could not import d4rl") from self.D4RL_ERR + if from_env is None: + warnings.warn( + "from_env will soon default to ``False``, ie the data will be " + "downloaded without relying on d4rl by default. " + "For now, ``True`` will still be the default. " + "To disable this warning, explicitly pass the ``from_env`` argument " + "during construction of the dataset.", + category=DeprecationWarning, + ) + from_env = True self.from_env = from_env - self.use_timeout_as_done = use_timeout_as_done - if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) + self.use_truncated_as_done = use_truncated_as_done + + if not from_env and direct_download is None: + self._import_d4rl() + direct_download = not self._has_d4rl + + if not direct_download: + if terminate_on_end is None: + # we use the default of d4rl + terminate_on_end = False + self._import_d4rl() + + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR + + if from_env: + dataset = self._get_dataset_from_env(name, env_kwargs) + else: + if self.use_truncated_as_done: + warnings.warn( + "Using use_truncated_as_done=True + terminate_on_end=True " + "with from_env=False may not have the intended effect " + "as the timeouts (truncation) " + "can be absent from the static dataset." + ) + env_kwargs.update({"terminate_on_end": terminate_on_end}) + dataset = self._get_dataset_direct(name, env_kwargs) else: - dataset = self._get_dataset_direct(name, env_kwargs) + if terminate_on_end is False: + raise ValueError( + "Using terminate_on_end=False is not compatible with direct_download=True." + ) + dataset = self._get_dataset_direct_download(name, env_kwargs) # Fill unknown next states with 0 dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 if split_trajs: dataset = split_trajectories(dataset) + dataset["next", "done"][:, -1] = True + storage = LazyMemmapStorage(dataset.shape[0]) super().__init__( batch_size=batch_size, @@ -149,6 +200,23 @@ def __init__( ) self.extend(dataset) + def _get_dataset_direct_download(self, name, env_kwargs): + """Directly download and use a D4RL dataset.""" + if env_kwargs: + raise RuntimeError( + f"Cannot pass env_kwargs when `direct_download=True`. 
Got env_kwargs keys: {env_kwargs.keys()}" + ) + url = D4RL_DATASETS.get(name, None) + if url is None: + raise KeyError(f"Env {name} not found.") + h5path = _download_dataset_from_url(url) + # h5path_parent = Path(h5path).parent + dataset = PersistentTensorDict.from_h5(h5path) + dataset = dataset.to_tensordict() + with dataset.unlock_(): + dataset = self._process_data_from_env(dataset) + return dataset + def _get_dataset_direct(self, name, env_kwargs): from torchrl.envs.libs.gym import GymWrapper @@ -179,22 +247,19 @@ def _get_dataset_direct(self, name, env_kwargs): dataset = dataset.unflatten_keys("/") else: self.metadata = {} - dataset.rename_key("observations", "observation") + dataset.rename_key_("observations", "observation") dataset.set("next", dataset.select()) - dataset.rename_key("next_observations", ("next", "observation")) - dataset.rename_key("terminals", "terminal") + dataset.rename_key_("next_observations", ("next", "observation")) + dataset.rename_key_("terminals", "terminated") if "timeouts" in dataset.keys(): - dataset.rename_key("timeouts", "timeout") - if self.use_timeout_as_done: - dataset.set( - "done", - dataset.get("terminal") - | dataset.get("timeout", torch.zeros((), dtype=torch.bool)), - ) + dataset.rename_key_("timeouts", "truncated") + if self.use_truncated_as_done: + done = dataset.get("terminated") | dataset.get("truncated", False) + dataset.set("done", done) else: - dataset.set("done", dataset.get("terminal")) - dataset.rename_key("rewards", "reward") - dataset.rename_key("actions", "action") + dataset.set("done", dataset.get("terminated")) + dataset.rename_key_("rewards", "reward") + dataset.rename_key_("actions", "action") # let's make sure that the dtypes match what's expected for key, spec in env.observation_spec.items(True, True): @@ -202,13 +267,16 @@ def _get_dataset_direct(self, name, env_kwargs): dataset["next", key] = dataset["next", key].to(spec.dtype) dataset["action"] = dataset["action"].to(env.action_spec.dtype) dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) - dataset["done"] = dataset["done"].bool() - dataset["done"] = dataset["done"].unsqueeze(-1) - # dataset.rename_key("next_observations", "next/observation") + # format done etc + dataset["done"] = dataset["done"].bool().unsqueeze(-1) + dataset["terminated"] = dataset["terminated"].bool().unsqueeze(-1) + if "truncated" in dataset.keys(): + dataset["truncated"] = dataset["truncated"].bool().unsqueeze(-1) + # dataset.rename_key_("next_observations", "next/observation") dataset["reward"] = dataset["reward"].unsqueeze(-1) dataset["next"].update( - dataset.select("reward", "done", "terminal", "timeout", strict=False) + dataset.select("reward", "done", "terminated", "truncated", strict=False) ) dataset = ( dataset.clone() @@ -239,6 +307,10 @@ def _get_dataset_from_env(self, name, env_kwargs): } ) dataset = dataset.unflatten_keys("/") + dataset = self._process_data_from_env(dataset, env) + return dataset + + def _process_data_from_env(self, dataset, env=None): if "metadata" in dataset.keys(): metadata = dataset.get("metadata") dataset = dataset.exclude("metadata") @@ -249,53 +321,99 @@ def _get_dataset_from_env(self, name, env_kwargs): else: self.metadata = {} - dataset.rename_key("observations", "observation") - dataset.rename_key("terminals", "terminal") + dataset.rename_key_("observations", "observation") + dataset.rename_key_("terminals", "terminated") if "timeouts" in dataset.keys(): - dataset.rename_key("timeouts", "timeout") - if self.use_timeout_as_done: + 
dataset.rename_key_("timeouts", "truncated") + if self.use_truncated_as_done: dataset.set( "done", - dataset.get("terminal") - | dataset.get("timeout", torch.zeros((), dtype=torch.bool)), + dataset.get("terminated") | dataset.get("truncated", False), ) else: - dataset.set("done", dataset.get("terminal")) - dataset.rename_key("rewards", "reward") - dataset.rename_key("actions", "action") + dataset.set("done", dataset.get("terminated")) + + dataset.rename_key_("rewards", "reward") + dataset.rename_key_("actions", "action") try: - dataset.rename_key("infos", "info") + dataset.rename_key_("infos", "info") except KeyError: pass # let's make sure that the dtypes match what's expected - for key, spec in env.observation_spec.items(True, True): - dataset[key] = dataset[key].to(spec.dtype) - dataset["action"] = dataset["action"].to(env.action_spec.dtype) - dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) - dataset["done"] = dataset["done"].bool() + if env is not None: + for key, spec in env.observation_spec.items(True, True): + dataset[key] = dataset[key].to(spec.dtype) + dataset["action"] = dataset["action"].to(env.action_spec.dtype) + dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) + + # format done + dataset["done"] = dataset["done"].bool().unsqueeze(-1) + dataset["terminated"] = dataset["terminated"].bool().unsqueeze(-1) + if "truncated" in dataset.keys(): + dataset["truncated"] = dataset["truncated"].bool().unsqueeze(-1) - dataset["done"] = dataset["done"].unsqueeze(-1) - # dataset.rename_key("next_observations", "next/observation") dataset["reward"] = dataset["reward"].unsqueeze(-1) dataset = dataset[:-1].set( "next", dataset.select("observation", "info", strict=False)[1:], ) dataset["next"].update( - dataset.select("reward", "done", "terminal", "timeout", strict=False) + dataset.select("reward", "done", "terminated", "truncated", strict=False) ) dataset = ( dataset.clone() ) # make sure that all tensors have a different data_ptr self._shift_reward_done(dataset) - self.specs = env.specs.clone() + if env is not None: + self.specs = env.specs.clone() + else: + self.specs = None return dataset def _shift_reward_done(self, dataset): dataset["reward"] = dataset["reward"].clone() - dataset["done"] = dataset["done"].clone() dataset["reward"][1:] = dataset["reward"][:-1].clone() - dataset["done"][1:] = dataset["done"][:-1].clone() dataset["reward"][0] = 0 - dataset["done"][0] = 0 + for key in ("done", "terminated", "truncated"): + if key not in dataset.keys(): + continue + dataset[key] = dataset[key].clone() + dataset[key][1:] = dataset[key][:-1].clone() + dataset[key][0] = 0 + + +def _download_dataset_from_url(dataset_url): + dataset_filepath = _filepath_from_url(dataset_url) + if not os.path.exists(dataset_filepath): + print("Downloading dataset:", dataset_url, "to", dataset_filepath) + urllib.request.urlretrieve(dataset_url, dataset_filepath) + if not os.path.exists(dataset_filepath): + raise IOError("Failed to download dataset from %s" % dataset_url) + return dataset_filepath + + +def _filepath_from_url(dataset_url): + _, dataset_name = os.path.split(dataset_url) + dataset_filepath = os.path.join(DATASET_PATH, dataset_name) + return dataset_filepath + + +def _set_dataset_path(path): + global DATASET_PATH + DATASET_PATH = path + os.makedirs(path, exist_ok=True) + + +_set_dataset_path( + os.environ.get( + "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets") + ) +) + +if __name__ == "__main__": + data = D4RLExperienceReplay("kitchen-partial-v0", 
batch_size=128) + print(data) + for sample in data: + print(sample) + break diff --git a/torchrl/data/datasets/d4rl_infos.py b/torchrl/data/datasets/d4rl_infos.py new file mode 100644 index 00000000000..e9790ea04f9 --- /dev/null +++ b/torchrl/data/datasets/d4rl_infos.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +D4RL_DATASETS = { + "maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", + "maze2d-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5", + "maze2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5", + "maze2d-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5", + "maze2d-eval-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5", + "maze2d-eval-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5", + "maze2d-eval-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5", + "maze2d-open-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5", + "maze2d-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5", + "maze2d-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5", + "maze2d-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5", + "maze2d-eval-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5", + "maze2d-eval-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5", + "maze2d-eval-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5", + "minigrid-fourrooms-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5", + "minigrid-fourrooms-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5", + "pen-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", + "pen-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5", + "pen-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5", + "hammer-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", + "hammer-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5", + "hammer-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5", + "relocate-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", + "relocate-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5", + "relocate-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5", + "door-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", + "door-cloned-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5", + "door-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5", + "halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5", + "halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5", + "halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5", + "halfcheetah-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5", + "halfcheetah-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5", + "walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5", + "walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5", + "walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5", + "walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5", + "walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5", + "hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5", + "hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5", + "hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5", + "hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5", + "hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5", + "ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5", + "ant-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5", + "ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5", + "ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5", + "ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5", + "ant-random-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5", + "antmaze-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5", + "antmaze-umaze-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-umaze-v2": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5", + "antmaze-umaze-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "antmaze-medium-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", + "antmaze-medium-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "antmaze-large-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", + "antmaze-large-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "flow-ring-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", + "flow-ring-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", + "flow-merge-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", + "flow-merge-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", + "kitchen-complete-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5", + "kitchen-partial-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5", + "kitchen-mixed-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5", + "carla-lane-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", + "carla-town-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", + "carla-town-full-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", + "bullet-halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", + "bullet-halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", + "bullet-halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", + "bullet-halfcheetah-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", + "bullet-halfcheetah-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", + "bullet-hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", + "bullet-hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", + "bullet-hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", + "bullet-hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", + "bullet-hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", + "bullet-ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", + "bullet-ant-medium-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", + "bullet-ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", + "bullet-ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", + "bullet-ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", + "bullet-walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", + "bullet-walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", + "bullet-walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", + "bullet-walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", + "bullet-walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", + "bullet-maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", + "bullet-maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", + "bullet-maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", + "bullet-maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", + "halfcheetah-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_random-v1.hdf5", + "halfcheetah-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_random-v2.hdf5", + "halfcheetah-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium-v1.hdf5", + "halfcheetah-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium-v2.hdf5", + "halfcheetah-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_expert-v1.hdf5", + "halfcheetah-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_expert-v2.hdf5", + "halfcheetah-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_replay-v1.hdf5", + "halfcheetah-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_replay-v2.hdf5", + "halfcheetah-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_full_replay-v1.hdf5", + "halfcheetah-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_full_replay-v2.hdf5", + "halfcheetah-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_expert-v1.hdf5", + "halfcheetah-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_expert-v2.hdf5", + "hopper-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_random-v1.hdf5", + "hopper-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_random-v2.hdf5", + "hopper-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium-v1.hdf5", + "hopper-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium-v2.hdf5", + "hopper-expert-v1": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_expert-v1.hdf5", + "hopper-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_expert-v2.hdf5", + "hopper-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_replay-v1.hdf5", + "hopper-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_replay-v2.hdf5", + "hopper-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_full_replay-v1.hdf5", + "hopper-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_full_replay-v2.hdf5", + "hopper-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_expert-v1.hdf5", + "hopper-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_expert-v2.hdf5", + "walker2d-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_random-v1.hdf5", + "walker2d-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_random-v2.hdf5", + "walker2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium-v1.hdf5", + "walker2d-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium-v2.hdf5", + "walker2d-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_expert-v1.hdf5", + "walker2d-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_expert-v2.hdf5", + "walker2d-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_replay-v1.hdf5", + "walker2d-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_replay-v2.hdf5", + "walker2d-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_full_replay-v1.hdf5", + "walker2d-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_full_replay-v2.hdf5", + "walker2d-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_expert-v1.hdf5", + "walker2d-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_expert-v2.hdf5", + "ant-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_random-v1.hdf5", + "ant-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_random-v2.hdf5", + "ant-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium-v1.hdf5", + "ant-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium-v2.hdf5", + "ant-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_expert-v1.hdf5", + "ant-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_expert-v2.hdf5", + "ant-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_replay-v1.hdf5", + "ant-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_replay-v2.hdf5", + "ant-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_full_replay-v1.hdf5", + "ant-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_full_replay-v2.hdf5", + "ant-medium-expert-v1": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_expert-v1.hdf5", + "ant-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_expert-v2.hdf5", + "hammer-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-human-v1.hdf5", + "hammer-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-expert-v1.hdf5", + "hammer-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-cloned-v1.hdf5", + "pen-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-human-v1.hdf5", + "pen-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-expert-v1.hdf5", + "pen-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-cloned-v1.hdf5", + "relocate-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-human-v1.hdf5", + "relocate-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-expert-v1.hdf5", + "relocate-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-cloned-v1.hdf5", + "door-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-human-v1.hdf5", + "door-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-expert-v1.hdf5", + "door-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-cloned-v1.hdf5", + "antmaze-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5", + "antmaze-umaze-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-medium-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-medium-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-umaze-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-eval-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-eval-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "door-human-longhorizon-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", + "hammer-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", + "pen-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", + "relocate-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", + "maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5", + "maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5", + "maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5", + "maze2d-umaze-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5", + "maze2d-medium-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5", + "maze2d-large-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5", + "carla-lane-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5", + "carla-town-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", +} diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 3369e2004bc..5d21d202eae 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -248,6 +248,13 @@ def add(self, data: Any) -> int: Returns: index where the data lives in the replay buffer. """ + if self._transform is not None and ( + is_tensor_collection(data) or len(self._transform) + ): + data = self._transform.inv(data) + return self._add(data) + + def _add(self, data): with self._replay_lock: index = self._writer.add(data) self._sampler.add(index) @@ -271,9 +278,9 @@ def extend(self, data: Sequence) -> torch.Tensor: Returns: Indices of the data added to the replay buffer. """ - if self._transform is not None and is_tensor_collection(data): - data = self._transform.inv(data) - elif self._transform is not None and len(self._transform): + if self._transform is not None and ( + is_tensor_collection(data) or len(self._transform) + ): data = self._transform.inv(data) return self._extend(data) @@ -410,6 +417,27 @@ def __iter__(self): data = self.sample() yield data + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + _replay_lock = state.pop("_replay_lock", None) + _futures_lock = state.pop("_futures_lock", None) + if _replay_lock is not None: + state["_replay_lock_placeholder"] = None + if _futures_lock is not None: + state["_futures_lock_placeholder"] = None + return state + + def __setstate__(self, state: Dict[str, Any]): + if "_replay_lock_placeholder" in state: + state.pop("_replay_lock_placeholder") + _replay_lock = threading.RLock() + state["_replay_lock"] = _replay_lock + if "_futures_lock_placeholder" in state: + state.pop("_futures_lock_placeholder") + _futures_lock = threading.RLock() + state["_futures_lock"] = _futures_lock + self.__dict__.update(state) + class PrioritizedReplayBuffer(ReplayBuffer): """Prioritized replay buffer. 
@@ -634,13 +662,14 @@ def __init__(self, *, priority_key: str = "td_error", **kw) -> None: super().__init__(**kw) self.priority_key = priority_key - def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: + def _get_priority_item(self, tensordict: TensorDictBase) -> float: if "_data" in tensordict.keys(): tensordict = tensordict.get("_data") - if self.priority_key not in tensordict.keys(): + + priority = tensordict.get(self.priority_key, None) + if priority is None: return self._sampler.default_priority try: - priority = tensordict.get(self.priority_key) if priority.numel() > 1: priority = _reduce(priority, self._sampler.reduction) else: @@ -653,20 +682,42 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: ) return priority + def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor: + if "_data" in tensordict.keys(): + tensordict = tensordict.get("_data") + + priority = tensordict.get(self.priority_key, None) + if priority is None: + return torch.tensor( + self._sampler.default_priority, + dtype=torch.float, + device=tensordict.device, + ).expand(tensordict.shape[0]) + + priority = priority.reshape(priority.shape[0], -1) + priority = _reduce(priority, self._sampler.reduction, dim=1) + + return priority + def add(self, data: TensorDictBase) -> int: + if self._transform is not None: + data = self._transform.inv(data) + if is_tensor_collection(data): data_add = TensorDict( { "_data": data, }, batch_size=[], + device=data.device, ) if data.batch_size: data_add["_rb_batch_size"] = torch.tensor(data.batch_size) else: data_add = data - index = super().add(data_add) + + index = super()._add(data_add) if is_tensor_collection(data_add): data_add.set("index", index) @@ -675,62 +726,50 @@ def add(self, data: TensorDictBase) -> int: self.update_tensordict_priority(data_add) return index - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: - if is_tensor_collection(tensordicts): - tensordicts = TensorDict( - {"_data": tensordicts}, batch_size=tensordicts.batch_size[:1] - ) - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - # we keep track of the batch size to reinstantiate it when sampling - if "_rb_batch_size" in tensordicts.keys(): - raise KeyError( - "conflicting key '_rb_batch_size'. Consider removing from data." - ) - shape = torch.tensor(tensordicts.batch_size[1:]).expand( - tensordicts.batch_size[0], tensordicts.batch_dims - 1 + def extend(self, tensordicts: TensorDictBase) -> torch.Tensor: + + tensordicts = TensorDict( + {"_data": tensordicts}, + batch_size=tensordicts.batch_size[:1], + ) + if tensordicts.batch_dims > 1: + # we want the tensordict to have one dimension only. The batch size + # of the sampled tensordicts can be changed thereafter + if not isinstance(tensordicts, LazyStackedTensorDict): + tensordicts = tensordicts.clone(recurse=False) + else: + tensordicts = tensordicts.contiguous() + # we keep track of the batch size to reinstantiate it when sampling + if "_rb_batch_size" in tensordicts.keys(): + raise KeyError( + "conflicting key '_rb_batch_size'. Consider removing from data." 
) - tensordicts.set("_rb_batch_size", shape) - tensordicts.set( - "index", - torch.zeros( - tensordicts.shape, device=tensordicts.device, dtype=torch.int - ), + shape = torch.tensor(tensordicts.batch_size[1:]).expand( + tensordicts.batch_size[0], tensordicts.batch_dims - 1 ) - - if not is_tensor_collection(tensordicts): - stacked_td = torch.stack(tensordicts, 0) - else: - stacked_td = tensordicts + tensordicts.set("_rb_batch_size", shape) + tensordicts.set( + "index", + torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int), + ) if self._transform is not None: - stacked_td.set("_data", self._transform.inv(stacked_td.get("_data"))) - - index = super()._extend(stacked_td) - # stacked_td.set( - # "index", - # torch.tensor(index, dtype=torch.int, device=stacked_td.device), - # inplace=True, - # ) - self.update_tensordict_priority(stacked_td) + data = self._transform.inv(tensordicts.get("_data")) + tensordicts.set("_data", data) + if data.device is not None: + tensordicts = tensordicts.to(data.device) + + index = super()._extend(tensordicts) + self.update_tensordict_priority(tensordicts) return index def update_tensordict_priority(self, data: TensorDictBase) -> None: if not isinstance(self._sampler, PrioritizedSampler): return if data.ndim: - priority = torch.tensor( - [self._get_priority(td) for td in data], - dtype=torch.float, - device=data.device, - ) + priority = self._get_priority_vector(data) else: - priority = self._get_priority(data) + priority = self._get_priority_item(data) index = data.get("index") while index.shape != priority.shape: # reduce index @@ -977,17 +1016,23 @@ def __call__(self, list_of_tds): return self.out -def _reduce(tensor: torch.Tensor, reduction: str): +def _reduce( + tensor: torch.Tensor, reduction: str, dim: Optional[int] = None +) -> Union[float, torch.Tensor]: """Reduces a tensor given the reduction method.""" if reduction == "max": - return tensor.max().item() + result = tensor.max(dim=dim) elif reduction == "min": - return tensor.min().item() + result = tensor.min(dim=dim) elif reduction == "mean": - return tensor.mean().item() + result = tensor.mean(dim=dim) elif reduction == "median": - return tensor.median().item() - raise NotImplementedError(f"Unknown reduction method {reduction}") + result = tensor.median(dim=dim) + else: + raise NotImplementedError(f"Unknown reduction method {reduction}") + if isinstance(result, tuple): + result = result[0] + return result.item() if dim is None else result def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 436c6cf76e3..f2b28f373b4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -171,11 +171,14 @@ class TensorStorage(Storage): """A storage for tensors and tensordicts. Args: - data (tensor or TensorDict): the data buffer to be used. + storage (tensor or TensorDict): the data buffer to be used. max_size (int): size of the storage, i.e. maximum number of elements stored in the buffer. device (torch.device, optional): device where the sampled tensors will be stored and sent. Default is :obj:`torch.device("cpu")`. + If "auto" is passed, the device is automatically gathered from the + first batch of data passed. This is not enabled by default to avoid + data placed on GPU by mistake, causing OOM issues. 
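A hedged sketch of the "auto" option just described: the storage keeps the string until the first batch arrives, then adopts that batch's device. The tensordict below is given an explicit device so the inference has something to read; names and sizes are illustrative.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

storage = LazyTensorStorage(max_size=100, device="auto")
rb = ReplayBuffer(storage=storage)
rb.extend(TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8], device="cpu"))
print(storage.device)                        # torch.device("cpu"), taken from the first batch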
Examples: >>> data = TensorDict({ @@ -230,7 +233,7 @@ def __new__(cls, *args, **kwargs): cls._storage = None return super().__new__(cls) - def __init__(self, storage, max_size=None, device=None): + def __init__(self, storage, max_size=None, device="cpu"): if not ((storage is None) ^ (max_size is None)): if storage is None: raise ValueError("Expected storage to be non-null.") @@ -247,7 +250,13 @@ def __init__(self, storage, max_size=None, device=None): self._len = max_size else: self._len = 0 - self.device = device if device else torch.device("cpu") + self.device = ( + torch.device(device) + if device != "auto" + else storage.device + if storage is not None + else "auto" + ) self._storage = storage def state_dict(self) -> Dict[str, Any]: @@ -345,6 +354,9 @@ class LazyTensorStorage(TensorStorage): in the buffer. device (torch.device, optional): device where the sampled tensors will be stored and sent. Default is :obj:`torch.device("cpu")`. + If "auto" is passed, the device is automatically gathered from the + first batch of data passed. This is not enabled by default to avoid + data placed on GPU by mistake, causing OOM issues. Examples: >>> data = TensorDict({ @@ -396,12 +408,14 @@ class LazyTensorStorage(TensorStorage): """ - def __init__(self, max_size, device=None): - super().__init__(None, max_size, device=device) + def __init__(self, max_size, device="cpu"): + super().__init__(storage=None, max_size=max_size, device=device) def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if VERBOSE: print("Creating a TensorStorage...") + if self.device == "auto": + self.device = data.device if isinstance(data, torch.Tensor): # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype out = torch.empty( @@ -436,6 +450,9 @@ class LazyMemmapStorage(LazyTensorStorage): scratch_dir (str or path): directory where memmap-tensors will be written. device (torch.device, optional): device where the sampled tensors will be stored and sent. Default is :obj:`torch.device("cpu")`. + If ``None`` is provided, the device is automatically gathered from the + first batch of data passed. This is not enabled by default to avoid + data placed on GPU by mistake, causing OOM issues. 
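The same lazy device resolution applies to the memory-mapped storage. Note that the constructor below compares the device against the string "auto", so that is the sentinel used in this sketch; the scratch directory is a temporary folder chosen purely for illustration.

import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

scratch = tempfile.mkdtemp()
rb = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=100, scratch_dir=scratch, device="auto")
)
rb.extend(TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8], device="cpu"))
# memmap files for "obs" now live under `scratch`; the storage device was read from the batch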
Examples: >>> data = TensorDict({ @@ -486,7 +503,7 @@ class LazyMemmapStorage(LazyTensorStorage): """ - def __init__(self, max_size, scratch_dir=None, device=None): + def __init__(self, max_size, scratch_dir=None, device="cpu"): super().__init__(max_size) self.initialized = False self.scratch_dir = None @@ -494,7 +511,7 @@ def __init__(self, max_size, scratch_dir=None, device=None): self.scratch_dir = str(scratch_dir) if self.scratch_dir[-1] != "/": self.scratch_dir += "/" - self.device = device if device else torch.device("cpu") + self.device = torch.device(device) if device != "auto" else device self._len = 0 def state_dict(self) -> Dict[str, Any]: @@ -552,6 +569,8 @@ def load_state_dict(self, state_dict): def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if VERBOSE: print("Creating a MemmapStorage...") + if self.device == "auto": + self.device = data.device if isinstance(data, torch.Tensor): # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype out = MemmapTensor( @@ -682,7 +701,7 @@ def _get_default_collate(storage, _is_tensordict=False): return torch.utils.data._utils.collate.default_collate elif isinstance(storage, LazyMemmapStorage): return _collate_as_tensor - elif isinstance(storage, (LazyTensorStorage,)): + elif isinstance(storage, (TensorStorage,)): return _collate_contiguous else: raise NotImplementedError( diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index e2d10b19139..db2b6a418d6 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -308,16 +308,17 @@ def create_infinite_iterator(iterator): def get_dataloader( - batch_size, - block_size, - tensorclass_type, - device, - dataset_name=None, - infinite=True, - prefetch=0, - split="train", - root_dir=None, - from_disk=False, + batch_size: int, + block_size: int, + tensorclass_type: Type, + device: torch.device, + dataset_name: str | None = None, + infinite: bool = True, + prefetch: int = 0, + split: str = "train", + root_dir: str | None = None, + from_disk: bool = False, + num_workers: int | None = None, ): """Creates a dataset and returns a dataloader from it. @@ -346,9 +347,12 @@ def get_dataloader( from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk` will be used. Otherwise, :func:`datasets.load_dataset` will be used. Defaults to ``False``. + num_workers (int, optional): number of workers for :meth:`datasets.dataset.map` + which is called during tokenization. + Defaults to ``max(os.cpu_count() // 2, 1)``. Examples: - >>> from torchrl.data.rlhf.comparison import PairwiseDataset + >>> from torchrl.data.rlhf.reward import PairwiseDataset >>> dataloader = get_dataloader( ... batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu") >>> for d in dataloader: @@ -381,6 +385,7 @@ def get_dataloader( max_length=block_size, root_dir=root_dir, from_disk=from_disk, + num_workers=num_workers, ) out = TensorDictReplayBuffer( storage=TensorStorage(data), diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index 9e97f1f9c1e..d534a95379e 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
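A hedged usage sketch of the annotated `get_dataloader` signature above; the new `num_workers` value is forwarded to the datasets `.map()` call used for tokenization. It mirrors the doctest shipped with the function and needs the optional `datasets`/`transformers` dependencies plus network access.

from torchrl.data.rlhf.dataset import get_dataloader
from torchrl.data.rlhf.reward import PairwiseDataset

dataloader = get_dataloader(
    batch_size=256,
    block_size=550,
    tensorclass_type=PairwiseDataset,
    device="cpu",
    num_workers=2,       # defaults to max(os.cpu_count() // 2, 1) when left as None
)
batch = next(iter(dataloader))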
+from __future__ import annotations from typing import Optional @@ -41,7 +42,13 @@ def mask_label(self, pad_token_id=50256): @classmethod def from_dataset( - cls, split, dataset_name=None, max_length=550, root_dir=None, from_disk=False + cls, + split, + dataset_name=None, + max_length=550, + root_dir=None, + from_disk=False, + num_workers: int | None = None, ): """Returns a :class:`PromptData` from a dataset name. @@ -56,6 +63,9 @@ def from_dataset( from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk` will be used. Otherwise, :func:`datasets.load_dataset` will be used. Defaults to ``False``. + num_workers (int, optional): number of workers for :meth:`datasets.dataset.map` + which is called during tokenization. + Defaults to ``max(os.cpu_count() // 2, 1)``. Returns: a :class:`PromptData` instance containing a memory-mapped version of the required dataset. @@ -85,6 +95,7 @@ def from_dataset( PromptTensorDictTokenizer, root_dir=root_dir, from_disk=from_disk, + num_workers=num_workers, ) data = loader.load() return cls(**data, labels=data["input_ids"], batch_size=data.shape) diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py index 6726eb20c30..e7843e02f46 100644 --- a/torchrl/data/rlhf/reward.py +++ b/torchrl/data/rlhf/reward.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import importlib from typing import Optional @@ -66,7 +67,13 @@ class PairwiseDataset: @classmethod def from_dataset( - cls, split, dataset_name=None, max_length=550, root_dir=None, from_disk=False + cls, + split, + dataset_name: str | None = None, + max_length: int = 550, + root_dir: str | None = None, + from_disk: bool = False, + num_workers: int | None = None, ): """Returns a :class:`PairwiseDataset` from a dataset name. @@ -122,6 +129,7 @@ def from_dataset( pre_tokenization_hook, root_dir=root_dir, from_disk=from_disk, + num_workers=num_workers, ) data = loader.load() maxidx = data.shape[0] // 2 diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 2b22a7347dd..3cf2b6f7e4b 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -2,9 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import abc +import collections import importlib -from typing import Tuple +from typing import Sequence, Tuple +import numpy as np import torch from tensordict import TensorDict @@ -16,6 +21,91 @@ _has_transformers = importlib.util.find_spec("transformers") is not None +class KLControllerBase(abc.ABC): + """Base class for KL controllers. + + Each controller must implement an update method that takes the current KL value and + the number of steps and updates the kl_coef attribute of the wrapped model, + which will multiply the KL during calculation of the reward. + """ + + @abc.abstractmethod + def update(self, kl_values: float): + pass + + +class ConstantKLController(KLControllerBase): + """Constant KL Controller. + + This controller maintains a fixed coefficient no matter what values it is updated + with. + + Arguments: + model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + kl_coef (float): The coefficient to multiply KL with when calculating the + reward. 
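An illustrative sketch of the constant controller described above; `_DummyModel` is made up for the example, and the import assumes the class stays in `torchrl.data.rlhf.utils`, where this patch defines it.

from torchrl.data.rlhf.utils import ConstantKLController

class _DummyModel:
    kl_coef = 0.0                            # any object exposing `kl_coef` can be wrapped

ctrl = ConstantKLController(_DummyModel(), kl_coef=0.1)
ctrl.update()                                # kl_values are ignored; the coefficient stays at 0.1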
+ """ + + def __init__(self, model, kl_coef): + self.model = model + if not hasattr(model, "kl_coef"): + raise AttributeError( + "Model input to ConstantKLController doesn't have attribute 'kl_coef'" + ) + self.coef = kl_coef + self.model.kl_coef = self.coef + + def update(self, kl_values: Sequence[float] = None): + self.model.kl_coef = self.coef + + +class AdaptiveKLController(KLControllerBase): + """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". + + Arguments: + model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + init_kl_coef (float): The starting value of the coefficient. + target (float): The target KL value. When the observed KL is smaller, the + coefficient is decreased, thereby relaxing the KL penalty in the training + objective and allowing the model to stray further from the reference model. + When the observed KL is greater than the target, the KL coefficient is + increased, thereby pulling the model back towards the reference model. + horizon (int): Scaling factor to control how aggressively we update the + coefficient. + + Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 + Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py + """ + + def __init__(self, model, init_kl_coef: float, target: float, horizon: int): + self.model = model + self.coef = init_kl_coef + self.target = target + self.horizon = horizon + self.model.kl_coef = self.coef + + def update(self, kl_values: Sequence[float]): + """Update ``self.coef`` adaptively. + + Arguments: + kl_values (sequence of float): The current KL value between the newest policy and the initial + policy. + + """ + if kl_values is None: + raise ValueError( + f"The kl_values were not provided to {type(self)}. " + f"Make sure these values are provided for the scheduler to be updated " + f"accordingly. " + ) + n_steps = len(kl_values) + # renormalize kls + kl_value = -torch.tensor(kl_values).mean() / self.coef + proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ + mult = 1 + proportional_error * n_steps / self.horizon + self.coef *= mult # βₜ₊₁ + + class RolloutFromModel: """A class for performing rollouts with causal language models. @@ -33,10 +123,13 @@ class RolloutFromModel: reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given ``input_ids`` and ``attention_mask``, calculates rewards for each token and end_scores (the reward for the final token in each sequence). + kl_coef: (float, optional): initial kl coefficient. max_new_tokens (int, optional): the maximum length of the sequence. Defaults to 50. score_clip (float, optional): Scores from the reward model are clipped to the range ``(-score_clip, score_clip)``. Defaults to 10. + kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler. + num_steps (int, optional): number of steps between two optimization. 
Examples: >>> from tensordict.nn import TensorDictModule @@ -87,7 +180,15 @@ class RolloutFromModel: EOS_TOKEN_ID = 50256 def __init__( - self, model, ref_model, reward_model, max_new_tokens=50, score_clip=10.0 + self, + model, + ref_model, + reward_model, + kl_coef=0.1, + max_new_tokens=50, + score_clip=10.0, + kl_scheduler: KLControllerBase | None = None, + num_steps: int | None = None, ): if not _has_transformers: raise ImportError( @@ -99,18 +200,23 @@ def __init__( self.reward_model = reward_model self.max_new_tokens = max_new_tokens self.score_clip = score_clip - - def kl_step(self): - """Makes a step in the KL coefficient schedule.""" - raise NotImplementedError + self.kl_coef = kl_coef + self.kl_scheduler = kl_scheduler + if num_steps is not None: + self._kl_queue = collections.deque(maxlen=num_steps) + else: + # we create a list. Value appended to it will be detached scalars so very cheap to store, + # even if the update is not called. + # The scheduler update will take care of erasing these values. + self._kl_queue = [] @torch.no_grad() - def rollout_from_data(self, batch, kl_coef=0.1): + def rollout_from_data(self, batch): generated, log_probs, log_ratio = self.generate(batch) - return self.create_rollout_td(batch, generated, log_probs, log_ratio, kl_coef) + return self.create_rollout_td(batch, generated, log_probs, log_ratio) @torch.no_grad() - def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1): + def create_rollout_td(self, batch, generated, log_probs, log_ratio): """A TensorDict wrapper for generated data. This function takes a batch plus the generated tokens and replicates the @@ -142,9 +248,11 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) part of the inputs that will be used for generating the next token. - ``("next", "attention_mask")``: updated attention_mask after token has been generated. Passed to the generative model on the next time step - - ``("next", "done")``: Boolean array indicating whether we've reached a + - ``("next", "terminated")``: Boolean array indicating whether we've reached a terminal state (either because we generated EOS token or because we reached the token limit) + - ``("next", "done")``: Boolean array indicating whether we've reached a + final state. Currently a copy of ``"terminated"``. - ``("next", "reward")``: The reward received at each time step - ``("next", "reward_raw")``: The raw reward from the reward model, without the KL term. 
This is mainly for debugging and logging, it is not used in @@ -155,7 +263,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) rollout_generated = self._get_rollout_generated(generated, batch) rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() - done = self._get_done_status(generated, batch) + done, terminated = self._get_done_status(generated, batch) action = self._get_action(generated, batch) end_scores, end_scores_labels = self._get_end_scores( rollout_generated, rollout_attention_mask, batch @@ -167,7 +275,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) ) reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) reward_raw = reward_raw * done - reward_kl = -kl_coef * log_ratio.unsqueeze(-1) + reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1) reward = reward_raw + reward_kl td = { "action": action, @@ -178,11 +286,13 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) "input_ids": rollout_generated[:, 1:].clone(), "attention_mask": rollout_attention_mask[:, 1:].clone(), "done": done, + "terminated": terminated, "reward": reward, "reward_raw": reward_raw, "reward_kl": reward_kl, }, } + self._kl_queue.append(reward_kl.detach().mean()) return TensorDict( td, batch_size=done.shape[:2], device=generated.device ).refine_names(..., "time") @@ -206,13 +316,33 @@ def _get_done_status(self, generated, batch): (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, torch.tensor(self.max_new_tokens) - 1, ) - done = torch.zeros( + truncated_idx = ( + torch.tensor(self.max_new_tokens, device=generated.device).expand_as( + done_idx + ) + - 1 + ) + zeros = torch.zeros( done_idx.numel(), self.max_new_tokens, dtype=torch.bool, device=generated.device, ) - return done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + truncated = zeros.scatter(-1, truncated_idx.unsqueeze(-1), 1).unsqueeze(-1) + done = zeros.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + terminated = ( + done & ~truncated + ) # we assume that if it's not truncated, it was terminated + return truncated | terminated, terminated + + print("batch.prompt_rindex", batch.prompt_rindex) + print("generated", generated.shape) + terminated = (generated == self.EOS_TOKEN_ID)[..., -batch.prompt_rindex :] + terminated = terminated.int().cumsum(-1).bool() + done = terminated.clone() + done[..., self.max_new_tokens - 1] = 1 + print("self.max_new_tokens", self.max_new_tokens) + return done.unsqueeze(-1), terminated.unsqueeze(-1) def _get_action(self, generated, batch): # the sequence of actions for each trajectory is just the generated token ids @@ -390,3 +520,11 @@ def generate(self, batch: PromptData, generation_config=None): log_ratio = self._log_ratio(generated, batch.prompt_rindex) return generated, log_probs_gen, log_ratio + + def step_scheduler(self): + # recover true kl + self.kl_scheduler.update(self._kl_queue) + if isinstance(self._kl_queue, (list, collections.deque)): + # remove all values + while len(self._kl_queue): + self._kl_queue.remove(self._kl_queue[0]) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4d69949b964..d2d8c3233d9 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -32,8 +32,8 @@ import numpy as np import torch from tensordict import unravel_key -from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import _getitem_batch_size +from tensordict.tensordict import LazyStackedTensorDict, 
TensorDict, TensorDictBase +from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -69,6 +69,11 @@ _DEFAULT_SHAPE = torch.Size((1,)) DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined." +NOT_IMPLEMENTED_ERROR = NotImplementedError( + "method is not currently implemented." + " If you are interested in this feature please submit" + " an issue at https://github.com/pytorch/rl/issues" +) def _default_dtype_and_device( @@ -343,63 +348,89 @@ def clone(self) -> DiscreteBox: @dataclass(repr=False) class ContinuousBox(Box): - """A continuous box of values, in between a minimum and a maximum.""" + """A continuous box of values, in between a minimum (self.low) and a maximum (self.high).""" - _minimum: torch.Tensor - _maximum: torch.Tensor + _low: torch.Tensor + _high: torch.Tensor device: torch.device = None # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used. + @property + def low(self): + return self._low.to(self.device) + + @property + def high(self): + return self._high.to(self.device) + + @low.setter + def low(self, value): + self.device = value.device + self._low = value.cpu() + + @high.setter + def high(self, value): + self.device = value.device + self._high = value.cpu() + @property def minimum(self): - return self._minimum.to(self.device) + warnings.warn( + f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low", + category=DeprecationWarning, + ) + return self._low.to(self.device) @property def maximum(self): - return self._maximum.to(self.device) + warnings.warn( + f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.low", + category=DeprecationWarning, + ) + return self._high.to(self.device) - @minimum.setter - def minimum(self, value): + @low.setter + def low(self, value): self.device = value.device - self._minimum = value.cpu() + self._low = value.cpu() - @maximum.setter - def maximum(self, value): + @high.setter + def high(self, value): self.device = value.device - self._maximum = value.cpu() + self._high = value.cpu() def __post_init__(self): - self.minimum = self.minimum.clone() - self.maximum = self.maximum.clone() + self.low = self.low.clone() + self.high = self.high.clone() def __iter__(self): - yield self.minimum - yield self.maximum + yield self.low + yield self.high def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: - return self.__class__(self.minimum.to(dest), self.maximum.to(dest)) + return self.__class__(self.low.to(dest), self.high.to(dest)) def clone(self) -> ContinuousBox: - return self.__class__(self.minimum.clone(), self.maximum.clone()) + return self.__class__(self.low.clone(), self.high.clone()) def __repr__(self): min_str = indent( - f"\nminimum=Tensor(shape={self.minimum.shape}, device={self.minimum.device}, dtype={self.minimum.dtype}, contiguous={self.maximum.is_contiguous()})", + f"\nlow=Tensor(shape={self.low.shape}, device={self.low.device}, dtype={self.low.dtype}, contiguous={self.high.is_contiguous()})", " " * 4, ) max_str = indent( - f"\nmaximum=Tensor(shape={self.maximum.shape}, device={self.maximum.device}, dtype={self.maximum.dtype}, contiguous={self.maximum.is_contiguous()})", + f"\nhigh=Tensor(shape={self.high.shape}, device={self.high.device}, dtype={self.high.dtype}, contiguous={self.high.is_contiguous()})", " " * 4, ) - return f"{self.__class__.__name__}({min_str}, {max_str})" + return f"{self.__class__.__name__}({min_str},{max_str})" def __eq__(self, other): return ( type(self) == 
type(other) - and self.minimum.dtype == other.minimum.dtype - and self.maximum.dtype == other.maximum.dtype - and torch.equal(self.minimum, other.minimum) - and torch.equal(self.maximum, other.maximum) + and self.low.dtype == other.low.dtype + and self.high.dtype == other.high.dtype + and torch.equal(self.low, other.low) + and torch.equal(self.high, other.high) ) @@ -523,23 +554,27 @@ def encode( if not ignore_device: val = torch.tensor(val, device=self.device, dtype=self.dtype) else: - val = torch.as_tensor(val, dtype=self.dtype) - if val.shape[-len(self.shape) :] != self.shape: + val = torch.tensor(val, dtype=self.dtype) + if val.shape != self.shape: + # if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end - if ( - val.shape[-len(self.shape) :] == self.shape[:-1] - and self.shape[-1] == 1 - ): + if val.shape == self.shape and self.shape[-1] == 1: val = val.unsqueeze(-1) else: - raise RuntimeError( - f"Shape mismatch: the value has shape {val.shape} which " - f"is incompatible with the spec shape {self.shape}." - ) + try: + val = val.reshape(self.shape) + except Exception as err: + raise RuntimeError( + f"Shape mismatch: the value has shape {val.shape} which " + f"is incompatible with the spec shape {self.shape}." + ) from err if _CHECK_SPEC_ENCODE: self.assert_is_in(val) return val + def __ne__(self, other): + return not (self == other) + def __setattr__(self, key, value): if key == "shape": value = torch.Size(value) @@ -741,13 +776,16 @@ def __torch_function__( ) return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs) + def unbind(self, dim: int): + raise NotImplementedError + T = TypeVar("T") class _LazyStackedMixin(Generic[T]): def __init__(self, *specs: tuple[T, ...], dim: int) -> None: - self._specs = specs + self._specs = list(specs) self.dim = dim if self.dim < 0: self.dim = len(self.shape) + self.dim @@ -839,46 +877,12 @@ def __getitem__(self, item): return out return torch.stack(list(out), 0) - @property - def shape(self): - shape = list(self._specs[0].shape) - dim = self.dim - if dim < 0: - dim = len(shape) + dim + 1 - shape.insert(dim, len(self._specs)) - return torch.Size(shape) - def clone(self) -> T: - return torch.stack([spec.clone() for spec in self._specs], 0) + return torch.stack([spec.clone() for spec in self._specs], self.stack_dim) - def expand(self, *shape): - if len(shape) == 1 and not isinstance(shape[0], (int,)): - return self.expand(*shape[0]) - expand_shape = shape[: -len(self.shape)] - existing_shape = self.shape - shape_check = shape[-len(self.shape) :] - for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): - if size1 != size2 and size1 != 1: - raise RuntimeError( - f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" - ) - elif size1 != size2 and size1 == 1 and _i == self.dim: - # if we're expanding along the stack dim we just need to clone the existing spec - return torch.stack( - [self._specs[0].clone() for _ in range(size2)], self.dim - ).expand(*shape) - if _i != len(self.shape) - 1: - raise RuntimeError( - f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." 
- ) - # remove the stack dim from the expanded shape, which we know to match - unstack_shape = list(expand_shape) + [ - s for i, s in enumerate(shape_check) if i != self.dim - ] - return torch.stack( - [spec.expand(unstack_shape) for spec in self._specs], - self.dim + len(expand_shape), - ) + @property + def stack_dim(self): + return self.dim def zero(self, shape=None) -> TensorDictBase: if shape is not None: @@ -897,6 +901,75 @@ def rand(self, shape=None) -> TensorDictBase: def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: return torch.stack([spec.to(dest) for spec in self._specs], self.dim) + def unbind(self, dim: int): + if dim == self.stack_dim: + return self._specs + shape = self.shape + if dim < 0 or dim > self.ndim - 1 or shape[dim] == -1: + raise ValueError( + f"Provided dim {dim} is not valid for unbinding shape {shape}" + ) + else: + raise ValueError( + f"A {type(self)} instance can only be unbound along its stack dimension. Expected {self.stack_dim}, received {dim} instead." + ) + + def unsqueeze(self, dim: int): + if dim < 0: + new_dim = dim + len(self.shape) + 1 + else: + new_dim = dim + if new_dim > len(self.shape) or new_dim < 0: + raise ValueError(f"Cannot unsqueeze along dim {dim}.") + if new_dim > self.dim: + # unsqueeze 2, stack is on 1 => unsqueeze 1, stack along 1 + new_stack_dim = self.dim + new_dim = new_dim - 1 + else: + # unsqueeze 0, stack is on 1 => unsqueeze 0, stack on 1 + new_stack_dim = self.dim + 1 + return torch.stack( + [spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim + ) + + def squeeze(self, dim: int = None): + if dim is None: + size = self.shape + if len(size) == 1 or size.count(1) == 0: + return self + first_singleton_dim = size.index(1) + + squeezed_dict = self.squeeze(first_singleton_dim) + return squeezed_dict.squeeze(dim=None) + + if dim < 0: + new_dim = self.ndim + dim + else: + new_dim = dim + + if self.shape and (new_dim >= self.ndim or new_dim < 0): + raise RuntimeError( + f"squeezing is allowed for dims comprised between 0 and " + f"spec.ndim only. Got dim={dim} and shape" + f"={self.shape}." + ) + + if new_dim >= self.ndim or self.shape[new_dim] != 1: + return self + + if new_dim == self.dim: + return self._specs[0] + if new_dim > self.dim: + # squeeze 2, stack is on 1 => squeeze 1, stack along 1 + new_stack_dim = self.dim + new_dim = new_dim - 1 + else: + # squeeze 0, stack is on 1 => squeeze 0, stack on 1 + new_stack_dim = self.dim - 1 + return torch.stack( + [spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim + ) + class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): """A lazy representation of a stack of tensor specs. 
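The `unsqueeze` branch added to `_LazyStackedMixin` above boils down to a small piece of index bookkeeping; a pure-Python restatement, illustrative only and requiring no torchrl:

def _unsqueeze_dims(stack_dim, new_dim):
    # unsqueezing past the stack dim: the stack dim stays put and each member
    # spec is unsqueezed one position earlier; otherwise the stack dim shifts right
    if new_dim > stack_dim:
        return stack_dim, new_dim - 1
    return stack_dim + 1, new_dim

assert _unsqueeze_dims(stack_dim=1, new_dim=2) == (1, 1)
assert _unsqueeze_dims(stack_dim=1, new_dim=0) == (2, 0)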
@@ -912,13 +985,18 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): """ - @property - def space(self): - return self._specs[0].space - def __eq__(self, other): - # requires unbind to be implemented - pass + if not isinstance(other, LazyStackedTensorSpec): + return False + if len(self._specs) != len(other._specs): + return False + for _spec1, _spec2 in zip(self._specs, other._specs): + if _spec1 != _spec2: + return False + return True + + def __len__(self): + return self.shape[0] def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: if safe is None: @@ -933,30 +1011,15 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: spec.assert_is_in(v) return val.detach().cpu().numpy() - def __len__(self): - pass - - def project(self, val: TensorDictBase) -> TensorDictBase: - pass - def __repr__(self): shape_str = "shape=" + str(self.shape) - space_str = "space=" + str(self._specs[0].space) device_str = "device=" + str(self.device) dtype_str = "dtype=" + str(self.dtype) domain_str = "domain=" + str(self._specs[0].domain) - sub_string = ", ".join( - [shape_str, space_str, device_str, dtype_str, domain_str] - ) - string = f"{self.__class__.__name__}(\n {sub_string})" + sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str]) + string = f"LazyStacked{self._specs[0].__class__.__name__}(\n {sub_string})" return string - def __iter__(self): - pass - - def __setitem__(self, key, value): - pass - @property def device(self) -> DEVICE_TYPING: return self._specs[0].device @@ -968,16 +1031,75 @@ def ndim(self): def ndimension(self): return len(self.shape) - def set(self, name, spec): - if spec is not None: - shape = spec.shape - if shape[: self.ndim] != self.shape: - raise ValueError( - "The shape of the spec and the CompositeSpec mismatch: the first " - f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." + @property + def shape(self): + first_shape = self._specs[0].shape + shape = [] + for i in range(len(first_shape)): + homo_dim = True + for spec in self._specs: + if spec.shape[i] != first_shape[i]: + homo_dim = False + break + shape.append(first_shape[i] if homo_dim else -1) + + dim = self.dim + if dim < 0: + dim = len(shape) + dim + 1 + shape.insert(dim, len(self._specs)) + return torch.Size(shape) + + def expand(self, *shape): + if len(shape) == 1 and not isinstance(shape[0], (int,)): + return self.expand(*shape[0]) + expand_shape = shape[: -len(self.shape)] + existing_shape = self.shape + shape_check = shape[-len(self.shape) :] + for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): + if size1 != size2 and size1 != 1: + raise RuntimeError( + f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" ) - self._specs[name] = spec + elif size1 != size2 and size1 == 1 and _i == self.dim: + # if we're expanding along the stack dim we just need to clone the existing spec + return torch.stack( + [self._specs[0].clone() for _ in range(size2)], self.dim + ).expand(*shape) + if _i != len(self.shape) - 1: + raise RuntimeError( + f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." 
+ ) + # remove the stack dim from the expanded shape, which we know to match + shape_check = [s for i, s in enumerate(shape_check) if i != self.dim] + specs = [] + for spec in self._specs: + spec_shape = [] + for dim_check, spec_dim in zip(shape_check, spec.shape): + spec_shape.append(dim_check if dim_check != -1 else spec_dim) + unstack_shape = list(expand_shape) + list(spec_shape) + specs.append(spec.expand(unstack_shape)) + return torch.stack( + specs, + self.dim + len(expand_shape), + ) + + def type_check(self, value: torch.Tensor, key: str = None) -> None: + raise NOT_IMPLEMENTED_ERROR + + def is_in(self, val) -> bool: + raise NOT_IMPLEMENTED_ERROR + + @property + def space(self): + raise NOT_IMPLEMENTED_ERROR + + def _project(self, val: TensorDictBase) -> TensorDictBase: + raise NOT_IMPLEMENTED_ERROR + + def encode( + self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False + ) -> torch.Tensor: + raise NOT_IMPLEMENTED_ERROR @dataclass(repr=False) @@ -1028,10 +1150,10 @@ def __init__( n: int, shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, + dtype: Optional[Union[str, torch.dtype]] = torch.bool, use_register: bool = False, + mask: torch.Tensor | None = None, ): - dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register space = DiscreteBox(n) @@ -1045,6 +1167,17 @@ def __init__( f"Got n={space.n} and shape={shape}." ) super().__init__(shape, space, device, dtype, "discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(self.shape) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): @@ -1061,6 +1194,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: device=dest_device, dtype=dest_dtype, use_register=self.use_register, + mask=self.mask.to(dest) if self.mask is not None else None, ) def clone(self) -> OneHotDiscreteTensorSpec: @@ -1070,6 +1204,7 @@ def clone(self) -> OneHotDiscreteTensorSpec: device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=self.mask.clone() if self.mask is not None else None, ) def expand(self, *shape): @@ -1084,41 +1219,79 @@ def expand(self, *shape): f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." 
) + mask = self.mask + if mask is not None: + mask = mask.expand(shape) return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def squeeze(self, dim=None): if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): - raise ValueError( - "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" - ) + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _squeezed_shape(self.shape, dim) if shape is None: return self - + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, ) def unsqueeze(self, dim: int): if dim in (len(self.shape), -1): - raise ValueError( - "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" - ) + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _unsqueezed_shape(self.shape, dim) + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, + ) + + def unbind(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + mask = self.mask + if mask is not None: + mask = mask.unbind(dim) + else: + mask = (None,) * self.shape[dim] + return tuple( + self.__class__( + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + mask=mask[i], + ) + for i in range(self.shape[dim]) ) def rand(self, shape=None) -> torch.Tensor: @@ -1126,10 +1299,21 @@ def rand(self, shape=None) -> torch.Tensor: shape = self.shape[:-1] else: shape = torch.Size([*shape, *self.shape[:-1]]) - n = self.space.n - m = torch.randint(n, (*shape, 1), device=self.device) - out = torch.zeros((*shape, n), device=self.device, dtype=self.dtype) - out.scatter_(-1, m, 1) + mask = self.mask + if mask is None: + n = self.space.n + m = torch.randint(n, shape, device=self.device) + else: + mask = mask.expand(*shape, mask.shape[-1]) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + out = torch.nn.functional.one_hot(m, self.space.n).to(self.dtype) + # torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype) + # out.scatter_(-1, m, 1) return out def encode( @@ -1141,9 +1325,9 @@ def encode( ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if ignore_device: - val = torch.tensor(val, dtype=self.dtype) + val = torch.tensor(val) else: - val = torch.tensor(val, dtype=self.dtype, device=self.device) + val = torch.tensor(val, device=self.device) if space is None: space = self.space @@ -1156,7 +1340,7 @@ def encode( if (val >= space.n).any(): raise AssertionError("Value must be less than action space.") - val = torch.nn.functional.one_hot(val.long(), space.n) + val = torch.nn.functional.one_hot(val.long(), space.n).to(self.dtype) return val def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: @@ -1166,7 +1350,7 @@ def 
to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: if not isinstance(val, torch.Tensor): raise NotImplementedError self.assert_is_in(val) - val = val.argmax(-1).cpu().numpy() + val = val.long().argmax(-1).cpu().numpy() if self.use_register: inv_reg = self.space.register.inverse() vals = [] @@ -1197,18 +1381,43 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=self.mask[idx] if self.mask is not None else None, ) def _project(self, val: torch.Tensor) -> torch.Tensor: - # idx = val.sum(-1) != 1 - out = torch.nn.functional.gumbel_softmax(val.to(torch.float)) - out = (out == out.max(dim=-1, keepdim=True)[0]).to(torch.long) - return out + if self.mask is None: + out = torch.multinomial(val.to(torch.float), 1).squeeze(-1) + out = torch.nn.functional.one_hot(out, self.space.n).to(self.dtype) + return out + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + oob = ~gathered.any(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1) + val = val.clone() + val[oob] = 0 + val[oob] = torch.scatter(val[oob], -1, new_val, 1) + return val def is_in(self, val: torch.Tensor) -> bool: - return (val.sum(-1) == 1).all() + if self.mask is None: + return (val.sum(-1) == 1).all() + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + return gathered.any(-1).all() def __eq__(self, other): + if not hasattr(other, "mask"): + return False + mask_equal = (self.mask is None and other.mask is None) or ( + isinstance(self.mask, torch.Tensor) + and isinstance(other.mask, torch.Tensor) + and (self.mask.shape == other.mask.shape) + and (self.mask == other.mask).all() + ) return ( type(self) == type(other) and self.shape == other.shape @@ -1217,6 +1426,7 @@ def __eq__(self, other): and self.dtype == other.dtype and self.domain == other.domain and self.use_register == other.use_register + and mask_equal ) def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: @@ -1235,15 +1445,15 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: safe = _CHECK_SPEC_ENCODE if safe: self.assert_is_in(val) - return val.argmax(-1) + return val.long().argmax(-1) def to_categorical_spec(self) -> DiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return DiscreteTensorSpec( self.space.n, device=self.device, - dtype=self.dtype, shape=self.shape[:-1], + mask=self.mask, ) @@ -1252,41 +1462,62 @@ class BoundedTensorSpec(TensorSpec): """A bounded continuous tensor spec. Args: - minimum (np.ndarray, torch.Tensor or number): lower bound of the box. - maximum (np.ndarray, torch.Tensor or number): upper bound of the box. + low (np.ndarray, torch.Tensor or number): lower bound of the box. + high (np.ndarray, torch.Tensor or number): upper bound of the box. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. """ # SPEC_HANDLED_FUNCTIONS = {} + DEPRECATED_KWARGS = ( + "The `minimum` and `maximum` keyword arguments are now " + "deprecated in favour of `low` and `high`." + ) + CONFLICTING_KWARGS = ( + "The keyword arguments {} and {} conflict. Only one of these can be passed." 
+ ) def __init__( self, - minimum: Union[float, torch.Tensor, np.ndarray], - maximum: Union[float, torch.Tensor, np.ndarray], + low: Union[float, torch.Tensor, np.ndarray] = None, + high: Union[float, torch.Tensor, np.ndarray] = None, shape: Optional[Union[torch.Size, int]] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[torch.dtype, str]] = None, + **kwargs, ): + if "maximum" in kwargs: + if high is not None: + raise TypeError(self.CONFLICTING_KWARGS.format("high", "maximum")) + high = kwargs.pop("maximum") + warnings.warn(self.DEPRECATED_KWARGS, category=DeprecationWarning) + if "minimum" in kwargs: + if low is not None: + raise TypeError(self.CONFLICTING_KWARGS.format("low", "minimum")) + low = kwargs.pop("minimum") + warnings.warn(self.DEPRECATED_KWARGS, category=DeprecationWarning) + if len(kwargs): + raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") + dtype, device = _default_dtype_and_device(dtype, device) if dtype is None: dtype = torch.get_default_dtype() if device is None: device = torch._get_default_device() - if not isinstance(minimum, torch.Tensor): - minimum = torch.tensor(minimum, dtype=dtype, device=device) - if not isinstance(maximum, torch.Tensor): - maximum = torch.tensor(maximum, dtype=dtype, device=device) - if maximum.device != device: - maximum = maximum.to(device) - if minimum.device != device: - minimum = minimum.to(device) - if dtype is not None and minimum.dtype is not dtype: - minimum = minimum.to(dtype) - if dtype is not None and maximum.dtype is not dtype: - maximum = maximum.to(dtype) + if not isinstance(low, torch.Tensor): + low = torch.tensor(low, dtype=dtype, device=device) + if not isinstance(high, torch.Tensor): + high = torch.tensor(high, dtype=dtype, device=device) + if high.device != device: + high = high.to(device) + if low.device != device: + low = low.to(device) + if dtype is not None and low.dtype is not dtype: + low = low.to(dtype) + if dtype is not None and high.dtype is not dtype: + high = high.to(dtype) err_msg = ( "BoundedTensorSpec requires the shape to be explicitely (via " "the shape argument) or implicitely defined (via either the " @@ -1300,45 +1531,41 @@ def __init__( else: shape = torch.Size(list(shape)) - if maximum.ndimension(): - if shape is not None and shape != maximum.shape: + if high.ndimension(): + if shape is not None and shape != high.shape: raise RuntimeError(err_msg) - shape = maximum.shape - minimum = minimum.expand(shape).clone() - elif minimum.ndimension(): - if shape is not None and shape != minimum.shape: + shape = high.shape + low = low.expand(shape).clone() + elif low.ndimension(): + if shape is not None and shape != low.shape: raise RuntimeError(err_msg) - shape = minimum.shape - maximum = maximum.expand(shape).clone() + shape = low.shape + high = high.expand(shape).clone() elif shape is None: raise RuntimeError(err_msg) else: - minimum = minimum.expand(shape).clone() - maximum = maximum.expand(shape).clone() + low = low.expand(shape).clone() + high = high.expand(shape).clone() - if minimum.numel() > maximum.numel(): - maximum = maximum.expand_as(minimum).clone() - elif maximum.numel() > minimum.numel(): - minimum = minimum.expand_as(maximum).clone() + if low.numel() > high.numel(): + high = high.expand_as(low).clone() + elif high.numel() > low.numel(): + low = low.expand_as(high).clone() if shape is None: - shape = minimum.shape + shape = low.shape else: if isinstance(shape, float): shape = torch.Size([shape]) elif not isinstance(shape, torch.Size): shape = torch.Size(shape) 
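A hedged sketch of the renamed bounds: `low`/`high` are the new keyword arguments, while `minimum`/`maximum` keep working but emit a `DeprecationWarning` per the handling above.

import warnings

import torch
from torchrl.data import BoundedTensorSpec

spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), dtype=torch.float32)
assert spec.is_in(spec.rand())

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    BoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=(3,))
assert any(issubclass(w.category, DeprecationWarning) for w in caught)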
- shape_err_msg = ( - f"minimum and shape mismatch, got {minimum.shape} and {shape}" - ) - if len(minimum.shape) != len(shape): + shape_err_msg = f"low and shape mismatch, got {low.shape} and {shape}" + if len(low.shape) != len(shape): raise RuntimeError(shape_err_msg) - if not all(_s == _sa for _s, _sa in zip(shape, minimum.shape)): + if not all(_s == _sa for _s, _sa in zip(shape, low.shape)): raise RuntimeError(shape_err_msg) self.shape = shape - super().__init__( - shape, ContinuousBox(minimum, maximum), device, dtype, "continuous" - ) + super().__init__(shape, ContinuousBox(low, high), device, dtype, "continuous") def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): @@ -1353,8 +1580,8 @@ def expand(self, *shape): f"shape of the {self.__class__.__name__} spec in expand()." ) return self.__class__( - minimum=self.space.minimum.expand(shape).clone(), - maximum=self.space.maximum.expand(shape).clone(), + low=self.space.low.expand(shape).clone(), + high=self.space.high.expand(shape).clone(), shape=shape, device=self.device, dtype=self.dtype, @@ -1366,15 +1593,15 @@ def squeeze(self, dim: int | None = None): return self if dim is None: - minimum = self.space.minimum.squeeze().clone() - maximum = self.space.maximum.squeeze().clone() + low = self.space.low.squeeze().clone() + high = self.space.high.squeeze().clone() else: - minimum = self.space.minimum.squeeze(dim).clone() - maximum = self.space.maximum.squeeze(dim).clone() + low = self.space.low.squeeze(dim).clone() + high = self.space.high.squeeze(dim).clone() return self.__class__( - minimum=minimum, - maximum=maximum, + low=low, + high=high, shape=shape, device=self.device, dtype=self.dtype, @@ -1383,13 +1610,37 @@ def squeeze(self, dim: int | None = None): def unsqueeze(self, dim: int): shape = _unsqueezed_shape(self.shape, dim) return self.__class__( - minimum=self.space.minimum.unsqueeze(dim).clone(), - maximum=self.space.maximum.unsqueeze(dim).clone(), + low=self.space.low.unsqueeze(dim).clone(), + high=self.space.high.unsqueeze(dim).clone(), shape=shape, device=self.device, dtype=self.dtype, ) + def unbind(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
+ ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + low = self.space.low.unbind(dim) + high = self.space.high.unbind(dim) + return tuple( + self.__class__( + low=low[i], + high=high[i], + shape=shape, + device=self.device, + dtype=self.dtype, + ) + for i in range(self.shape[dim]) + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -1407,42 +1658,42 @@ def rand(self, shape=None) -> torch.Tensor: out[out < a] = a.expand_as(out)[out < a] return out else: - if self.space.maximum.dtype == torch.bool: - maxi = self.space.maximum.int() + if self.space.high.dtype == torch.bool: + maxi = self.space.high.int() else: - maxi = self.space.maximum - if self.space.minimum.dtype == torch.bool: - mini = self.space.minimum.int() + maxi = self.space.high + if self.space.low.dtype == torch.bool: + mini = self.space.low.int() else: - mini = self.space.minimum + mini = self.space.low interval = maxi - mini r = torch.rand(torch.Size([*shape, *self.shape]), device=interval.device) r = interval * r - r = self.space.minimum + r + r = self.space.low + r r = r.to(self.dtype).to(self.device) return r def _project(self, val: torch.Tensor) -> torch.Tensor: - minimum = self.space.minimum.to(val.device) - maximum = self.space.maximum.to(val.device) + low = self.space.low.to(val.device) + high = self.space.high.to(val.device) try: - val = val.clamp_(minimum.item(), maximum.item()) + val = val.clamp_(low.item(), high.item()) except ValueError: - minimum = minimum.expand_as(val) - maximum = maximum.expand_as(val) - val[val < minimum] = minimum[val < minimum] - val[val > maximum] = maximum[val > maximum] + low = low.expand_as(val) + high = high.expand_as(val) + val[val < low] = low[val < low] + val[val > high] = high[val > high] except RuntimeError: - minimum = minimum.expand_as(val) - maximum = maximum.expand_as(val) - val[val < minimum] = minimum[val < minimum] - val[val > maximum] = maximum[val > maximum] + low = low.expand_as(val) + high = high.expand_as(val) + val[val < low] = low[val < low] + val[val > high] = high[val > high] return val def is_in(self, val: torch.Tensor) -> bool: try: - return (val >= self.space.minimum.to(val.device)).all() and ( - val <= self.space.maximum.to(val.device) + return (val >= self.space.low.to(val.device)).all() and ( + val <= self.space.high.to(val.device) ).all() except RuntimeError as err: if "The size of tensor a" in str(err): @@ -1460,8 +1711,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if dest_device == self.device and dest_dtype == self.dtype: return self return self.__class__( - minimum=self.space.minimum.to(dest), - maximum=self.space.maximum.to(dest), + low=self.space.low.to(dest), + high=self.space.high.to(dest), shape=self.shape, device=dest_device, dtype=dest_dtype, @@ -1469,8 +1720,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: def clone(self) -> BoundedTensorSpec: return self.__class__( - minimum=self.space.minimum.clone(), - maximum=self.space.maximum.clone(), + low=self.space.low.clone(), + high=self.space.high.clone(), shape=self.shape, device=self.device, dtype=self.dtype, @@ -1486,8 +1737,8 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) # Expand is required as pytorch.tensor indexing return self.__class__( - minimum=self.space.minimum[idx].clone().expand(indexed_shape), - maximum=self.space.maximum[idx].clone().expand(indexed_shape), + 
low=self.space.low[idx].clone().expand(indexed_shape), + high=self.space.high[idx].clone().expand(indexed_shape), shape=indexed_shape, device=self.device, dtype=self.dtype, @@ -1585,6 +1836,24 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + return tuple( + self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + for i in range(self.shape[dim]) + ) + @dataclass(repr=False) class UnboundedDiscreteTensorSpec(TensorSpec): @@ -1648,10 +1917,10 @@ def clone(self) -> UnboundedDiscreteTensorSpec: def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - interval = self.space.maximum - self.space.minimum + interval = self.space.high - self.space.low r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) r = r * interval - r = self.space.minimum + r + r = self.space.low + r r = r.to(self.dtype) return r.to(self.device) @@ -1677,6 +1946,24 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + return tuple( + self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + for i in range(self.shape[dim]) + ) + @dataclass(repr=False) class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): @@ -1713,8 +2000,9 @@ def __init__( nvec: Sequence[int], shape: Optional[torch.Size] = None, device=None, - dtype=torch.long, + dtype=torch.bool, use_register=False, + mask: torch.Tensor | None = None, ): self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) @@ -1730,8 +2018,23 @@ def __init__( space = BoxList([DiscreteBox(n) for n in nvec]) self.use_register = use_register super(OneHotDiscreteTensorSpec, self).__init__( - shape, space, device, dtype, domain="discrete" + shape, + space, + device, + dtype, + domain="discrete", ) + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(*self.shape) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): @@ -1747,6 +2050,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: shape=self.shape, device=dest_device, dtype=dest_dtype, + mask=self.mask.to(dest) if self.mask is not None else None, ) def clone(self) -> MultiOneHotDiscreteTensorSpec: @@ -1755,6 +2059,27 @@ def clone(self) -> MultiOneHotDiscreteTensorSpec: shape=self.shape, device=self.device, dtype=self.dtype, + mask=self.mask.clone() if self.mask is not None else None, + ) + + def __eq__(self, other): + if not hasattr(other, "mask"): + return False + mask_equal = (self.mask is None 
and other.mask is None) or ( + isinstance(self.mask, torch.Tensor) + and isinstance(other.mask, torch.Tensor) + and (self.mask.shape == other.mask.shape) + and (self.mask == other.mask).all() + ) + return ( + type(self) == type(other) + and self.shape == other.shape + and self.space == other.space + and self.device == other.device + and self.dtype == other.dtype + and self.domain == other.domain + and self.use_register == other.use_register + and mask_equal ) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: @@ -1762,25 +2087,40 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: shape = self.shape[:-1] else: shape = torch.Size([*shape, *self.shape[:-1]]) - - x = torch.cat( - [ - torch.nn.functional.one_hot( - torch.randint( - space.n, - ( - *shape, - 1, + mask = self.mask + + if mask is None: + x = torch.cat( + [ + torch.nn.functional.one_hot( + torch.randint( + space.n, + ( + *shape, + 1, + ), + device=self.device, ), - device=self.device, - ), - space.n, - ).to(torch.long) - for space in self.space - ], - -1, - ).squeeze(-2) - return x + space.n, + ).to(self.dtype) + for space in self.space + ], + -1, + ).squeeze(-2) + return x + mask = mask.expand(*shape, mask.shape[-1]) + mask_splits = torch.split(mask, [space.n for space in self.space], -1) + out = [] + for _mask in mask_splits: + if mask.ndim > 2: + mask_flat = torch.flatten(_mask, 0, -2) + else: + mask_flat = _mask + shape_out = _mask.shape[:-1] + m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + m = torch.nn.functional.one_hot(m, _mask.shape[-1]).to(self.dtype) + out.append(m) + return torch.cat(out, -1) def encode( self, val: Union[np.ndarray, torch.Tensor], *, ignore_device: bool = False @@ -1802,7 +2142,7 @@ def encode( v, space, ignore_device=ignore_device ) ) - return torch.cat(x, -1) + return torch.cat(x, -1).reshape(self.shape) def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]: split_sizes = [space.n for space in self.space] @@ -1830,16 +2170,41 @@ def is_in(self, val: torch.Tensor) -> bool: vals = self._split(val) if vals is None: return False - return all( - super(MultiOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals - ) + return all(spec.is_in(val) for val, spec in zip(vals, self._split_self())) def _project(self, val: torch.Tensor) -> torch.Tensor: vals = self._split(val) - return torch.cat([super()._project(_val) for _val in vals], -1) - - def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: - """Converts a given one-hot tensor in categorical format. + return torch.cat( + [spec._project(val) for val, spec in zip(vals, self._split_self())], -1 + ) + + def _split_self(self): + result = [] + device = self.device + dtype = self.dtype + use_register = self.use_register + mask = ( + self.mask.split([space.n for space in self.space], -1) + if self.mask is not None + else [None] * len(self.space) + ) + for _mask, space in zip(mask, self.space): + n = space.n + shape = self.shape[:-1] + (n,) + result.append( + OneHotDiscreteTensorSpec( + n=n, + shape=shape, + device=device, + dtype=dtype, + use_register=use_register, + mask=_mask, + ) + ) + return result + + def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """Converts a given one-hot tensor in categorical format. Args: val (torch.Tensor, optional): One-hot tensor to convert in categorical format. 
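The hunks above give MultiOneHotDiscreteTensorSpec an optional boolean mask so that rand(), _project() and is_in() only consider allowed categories. A minimal usage sketch follows (illustrative only, not part of the patch; it assumes the constructor shown above and the torchrl.data.tensor_specs import path used later in this diff):

    # Sketch: sampling from a masked MultiOneHotDiscreteTensorSpec.
    # Assumes torchrl.data.tensor_specs exposes the class as imported below.
    import torch
    from torchrl.data.tensor_specs import MultiOneHotDiscreteTensorSpec

    # Two categorical groups of sizes 3 and 2, flattened into a 5-wide one-hot.
    # The mask forbids the second category of the first group.
    mask = torch.tensor([True, False, True, True, True])
    spec = MultiOneHotDiscreteTensorSpec(nvec=[3, 2], mask=mask)

    sample = spec.rand()                       # masked-out categories are never drawn
    assert spec.is_in(sample)                  # samples always satisfy the mask
    categorical = spec.to_categorical(sample)  # per-group categorical indices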
@@ -1855,15 +2220,15 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: if safe: self.assert_is_in(val) vals = self._split(val) - return torch.stack([val.argmax(-1) for val in vals], -1) + return torch.stack([val.long().argmax(-1) for val in vals], -1) def to_categorical_spec(self) -> MultiDiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return MultiDiscreteTensorSpec( [_space.n for _space in self.space], device=self.device, - dtype=self.dtype, shape=[*self.shape[:-1], len(self.space)], + mask=self.mask, ) def expand(self, *shape): @@ -1879,31 +2244,66 @@ def expand(self, *shape): f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." ) + mask = self.mask.expand(shape) if self.mask is not None else None return self.__class__( - nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype + nvec=nvecs, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def squeeze(self, dim=None): if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): - raise ValueError( - "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" - ) + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _squeezed_shape(self.shape, dim) if shape is None: return self + mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( - nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + nvec=self.nvec, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def unsqueeze(self, dim: int): if dim in (len(self.shape), -1): - raise ValueError( - "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" - ) + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _unsqueezed_shape(self.shape, dim) + mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( - nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask + ) + + def unbind(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
+ ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + mask = self.mask + if mask is None: + mask = (None,) * self.shape[dim] + else: + mask = mask.unbind(dim) + + return tuple( + self.__class__( + nvec=self.nvec, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask[i], + ) + for i in range(self.shape[dim]) ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): @@ -1956,34 +2356,71 @@ class DiscreteTensorSpec(TensorSpec): def __init__( self, n: int, - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype = torch.long, + mask: torch.Tensor | None = None, ): if shape is None: shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) super().__init__(shape, space, device, dtype, domain="discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(*self.shape, self.space.n) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - return torch.randint( - 0, - self.space.n, - torch.Size([*shape, *self.shape]), - device=self.device, - dtype=self.dtype, - ) + if self.mask is None: + return torch.randint( + 0, + self.space.n, + torch.Size([*shape, *self.shape]), + device=self.device, + dtype=self.dtype, + ) + mask = self.mask + mask = mask.expand(*shape, *mask.shape) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + return out def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): val = torch.round(val) - return val.clamp_(min=0, max=self.space.n - 1) + if self.mask is None: + return val.clamp_(min=0, max=self.space.n - 1) + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + oob = ~gathered.all(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1).squeeze(-1) + val = torch.masked_scatter(val, oob, new_val) + return val def is_in(self, val: torch.Tensor) -> bool: - return (0 <= val).all() and (val < self.space.n).all() + if self.mask is None: + return (0 <= val).all() and (val < self.space.n).all() + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" @@ -1996,6 +2433,14 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) def __eq__(self, other): + if not hasattr(other, "mask"): + return False + mask_equal = (self.mask is None and other.mask is None) or ( + isinstance(self.mask, torch.Tensor) + and isinstance(other.mask, torch.Tensor) + and (self.mask.shape == other.mask.shape) + and (self.mask == other.mask).all() + ) return ( type(self) == type(other) and self.shape == other.shape @@ -2003,9 +2448,14 @@ def __eq__(self, other): and 
self.device == other.device and self.dtype == other.dtype and self.domain == other.domain + and mask_equal ) - def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: + def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: + if safe is None: + safe = _CHECK_SPEC_ENCODE + # if not val.shape and not safe: + # return val.item() return super().to_numpy(val, safe) def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: @@ -2030,7 +2480,9 @@ def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: """Converts the spec to the equivalent one-hot spec.""" shape = [*self.shape, self.space.n] return OneHotDiscreteTensorSpec( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, ) def expand(self, *shape): @@ -2051,16 +2503,56 @@ def expand(self, *shape): def squeeze(self, dim=None): shape = _squeezed_shape(self.shape, dim) + mask = self.mask + if mask is not None: + mask = mask.view(*shape, mask.shape[-1]) + if shape is None: return self return self.__class__( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def unsqueeze(self, dim: int): shape = _unsqueezed_shape(self.shape, dim) + mask = self.mask + if mask is not None: + mask = mask.view(*shape, mask.shape[-1]) return self.__class__( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, + ) + + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + mask = self.mask + if mask is None: + mask = (None,) * self.shape[dim] + else: + mask = mask.unbind(dim) + return tuple( + self.__class__( + n=self.space.n, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask[i], + ) + for i in range(self.shape[dim]) ) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: @@ -2082,6 +2574,7 @@ def clone(self) -> DiscreteTensorSpec: shape=self.shape, device=self.device, dtype=self.dtype, + mask=self.mask.clone() if self.mask is not None else None, ) @@ -2132,7 +2625,7 @@ def expand(self, *shape): f"shape of the {self.__class__.__name__} spec in expand()." ) return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype ) def squeeze(self, dim=None): @@ -2140,13 +2633,32 @@ def squeeze(self, dim=None): if shape is None: return self return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype ) def unsqueeze(self, dim: int): shape = _unsqueezed_shape(self.shape, dim) return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + + def unbind(self, dim: int): + if dim in (len(self.shape) - 1, -1): + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") + + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
+ ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + return tuple( + self.__class__( + n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + for i in range(self.shape[dim]) ) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: @@ -2213,6 +2725,7 @@ def __init__( shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, + mask: torch.Tensor | None = None, ): if not isinstance(nvec, torch.Tensor): nvec = torch.tensor(nvec) @@ -2236,6 +2749,17 @@ def __init__( super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" ) + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(*self.shape[:-1], mask.shape[-1]) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): @@ -2246,8 +2770,32 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self + mask = self.mask.to(dest) if self.mask is not None else None return self.__class__( - n=self.nvec.to(dest), shape=None, device=dest_device, dtype=dest_dtype + n=self.nvec.to(dest), + shape=None, + device=dest_device, + dtype=dest_dtype, + mask=mask, + ) + + def __eq__(self, other): + if not hasattr(other, "mask"): + return False + mask_equal = (self.mask is None and other.mask is None) or ( + isinstance(self.mask, torch.Tensor) + and isinstance(other.mask, torch.Tensor) + and (self.mask.shape == other.mask.shape) + and (self.mask == other.mask).all() + ) + return ( + type(self) == type(other) + and self.shape == other.shape + and self.space == other.space + and self.device == other.device + and self.dtype == other.dtype + and self.domain == other.domain + and mask_equal ) def clone(self) -> MultiDiscreteTensorSpec: @@ -2256,6 +2804,7 @@ def clone(self) -> MultiDiscreteTensorSpec: shape=None, device=self.device, dtype=self.dtype, + mask=self.mask.clone() if self.mask is not None else None, ) def _rand(self, space: Box, shape: torch.Size, i: int): @@ -2276,6 +2825,9 @@ def _rand(self, space: Box, shape: torch.Size, i: int): return torch.stack(x, -1) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: + if self.mask is not None: + splits = self._split_self() + return torch.stack([split.rand(shape) for split in splits], -1) if shape is None: shape = self.shape[:-1] else: @@ -2288,7 +2840,42 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: x = x.squeeze(-1) return x + def _split_self(self): + result = [] + device = self.device + dtype = self.dtype + nvec = self.nvec + if nvec.ndim > 1: + nvec = torch.flatten(nvec, 0, -2)[0] + if (self.nvec != nvec).any(): + raise ValueError( + f"Only homogeneous MultiDiscrete specs can be masked, got nvec={self.nvec}." 
+ ) + nvec = nvec.tolist() + mask = ( + self.mask.split(nvec, -1) + if self.mask is not None + else [None] * len(self.space) + ) + for n, _mask in zip(nvec, mask): + shape = self.shape[:-1] + result.append( + DiscreteTensorSpec( + n=n, shape=shape, device=device, dtype=dtype, mask=_mask + ) + ) + return result + def _project(self, val: torch.Tensor) -> torch.Tensor: + if self.mask is not None: + return torch.stack( + [ + spec._project(_val) + for (_val, spec) in zip(val.unbind(-1), self._split_self()) + ], + -1, + ) + val_is_scalar = val.ndim < 1 if val_is_scalar: val = val.unsqueeze(0) @@ -2301,6 +2888,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: + if self.mask is not None: + return all( + spec.is_in(_val) + for (_val, spec) in zip(val.unbind(-1), self._split_self()) + ) + if val.ndim < 1: val = val.unsqueeze(0) val_have_wrong_dim = ( @@ -2351,8 +2944,8 @@ def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec: return MultiOneHotDiscreteTensorSpec( nvec, device=self.device, - dtype=self.dtype, shape=[*self.shape[:-1], sum(nvec)], + mask=self.mask, ) def expand(self, *shape): @@ -2367,16 +2960,22 @@ def expand(self, *shape): f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." ) + mask = ( + self.mask.expand(*shape, self.mask.shape[-1]) + if self.mask is not None + else None + ) return self.__class__( - nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + nvec=self.nvec, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def squeeze(self, dim: int | None = None): if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): - raise ValueError( - "Final dimension of MultiDiscreteTensorSpec must remain unchanged" - ) - + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _squeezed_shape(self.shape, dim) if shape is None: return self @@ -2385,20 +2984,55 @@ def squeeze(self, dim: int | None = None): nvec = self.nvec.squeeze() else: nvec = self.nvec.squeeze(dim) - + mask = self.mask + if mask is not None: + mask = mask.view(*shape[:-1], mask.shape[-1]) return self.__class__( - nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + nvec=nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask ) def unsqueeze(self, dim: int): if dim in (len(self.shape), -1): - raise ValueError( - "Final dimension of MultiDiscreteTensorSpec must remain unchanged" - ) + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _unsqueezed_shape(self.shape, dim) nvec = self.nvec.unsqueeze(dim) + mask = self.mask + if mask is not None: + mask = mask.view(*shape[:-1], mask.shape[-1]) return self.__class__( - nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + nvec=nvec, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, + ) + + def unbind(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError(f"Final dimension of {type(self)} must remain unchanged") + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
+ ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + mask = self.mask + nvec = self.nvec.unbind(dim) + if mask is not None: + mask = mask.unbind(dim) + else: + mask = (None,) * self.shape[dim] + return tuple( + self.__class__( + nvec=nvec[i], + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask[i], + ) + for i in range(self.shape[dim]) ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): @@ -2512,6 +3146,10 @@ def shape(self, value: torch.Size): ) self._shape = torch.Size(value) + def is_empty(self): + """Whether the composite spec contains specs or not.""" + return len(self._specs) == 0 + @property def ndim(self): return self.ndimension() @@ -2756,17 +3394,18 @@ def type_check( value = {selected_keys: value} selected_keys = [selected_keys] - for _key in self: + for _key in self.keys(): if self[_key] is not None and ( selected_keys is None or _key in selected_keys ): self._specs[_key].type_check(value[_key], _key) def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - for (key, item) in self._specs.items(): - if item is None: + for key, item in self._specs.items(): + if item is None or (isinstance(item, CompositeSpec) and item.is_empty()): continue - if not item.is_in(val.get(key)): + val_item = val.get(key) + if not item.is_in(val_item): return False return True @@ -2908,6 +3547,18 @@ def clone(self) -> CompositeSpec: shape=self.shape, ) + def empty(self): + """Create a spec like self, but with no entries.""" + try: + device = self.device + except RuntimeError: + device = self._device + return self.__class__( + {}, + device=device, + shape=self.shape, + ) + def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: return {key: self[key].to_numpy(val) for key, val in val.items()} @@ -3017,7 +3668,7 @@ def squeeze(self, dim: int | None = None): def unsqueeze(self, dim: int): if dim < 0: - dim += len(self.shape) + dim += len(self.shape) + 1 shape = _unsqueezed_shape(self.shape, dim) @@ -3035,6 +3686,25 @@ def unsqueeze(self, dim: int): device=device, ) + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = (s for i, s in enumerate(self.shape) if i != dim) + unbound_vals = {key: val.unbind(dim) for key, val in self.items()} + return tuple( + self.__class__( + {key: val[i] for key, val in unbound_vals.items()}, + shape=shape, + device=self.device, + ) + for i in range(self.shape[dim]) + ) + def lock_(self, recurse=False): """Locks the CompositeSpec and prevents modification of its content. 
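The hunks above also add is_empty() and empty() to CompositeSpec, together with an unbind() that splits a batched composite spec along one of its leading dimensions. A short sketch of the empty()/is_empty() pair (illustrative only, not part of the patch; it assumes the usual CompositeSpec keyword constructor and the UnboundedContinuousTensorSpec(shape=...) signature, neither of which is changed by this diff):

    # Sketch: an "empty" composite spec keeps batch shape and device but drops its entries.
    import torch
    from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec

    spec = CompositeSpec(
        observation=UnboundedContinuousTensorSpec(shape=torch.Size([4, 3])),
        shape=torch.Size([4]),
    )
    empty_spec = spec.empty()        # same class, same batch shape and device, no entries
    assert empty_spec.is_empty()
    assert empty_spec.shape == spec.shape
    assert not spec.is_empty()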
@@ -3115,11 +3785,28 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): """ - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: - pass + def update(self, dict) -> None: + for key, item in dict.items(): + if key in self.keys() and isinstance( + item, (Dict, CompositeSpec, LazyStackedCompositeSpec) + ): + for spec, sub_item in zip(self._specs, item.unbind(self.dim)): + spec[key].update(sub_item) + continue + self[key] = item + return self def __eq__(self, other): - pass + if not isinstance(other, LazyStackedCompositeSpec): + return False + if len(self._specs) != len(other._specs): + return False + if self.stack_dim != other.stack_dim: + return False + for _spec1, _spec2 in zip(self._specs, other._specs): + if _spec1 != _spec2: + return False + return True def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: if safe is None: @@ -3135,14 +3822,22 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: return {key: self[key].to_numpy(val) for key, val in val.items()} def __len__(self): - pass + return self.shape[0] - def values(self): - for key in self.keys(): + def values( + self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys(include_nested=include_nested, leaves_only=leaves_only): yield self[key] - def items(self): - for key in self.keys(): + def items( + self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys(include_nested=include_nested, leaves_only=leaves_only): yield key, self[key] def keys( @@ -3150,47 +3845,125 @@ def keys( include_nested: bool = False, leaves_only: bool = False, ) -> KeysView: - return self._specs[0].keys( + keys = self._specs[0].keys( include_nested=include_nested, leaves_only=leaves_only ) + keys = set(keys) + for spec in self._specs[1:]: + keys = keys.intersection(spec.keys(include_nested, leaves_only)) + return sorted(keys, key=str) def project(self, val: TensorDictBase) -> TensorDictBase: - pass - - def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - pass + vals = [] + for spec, subval in zip(self._specs, val.unbind(self.dim)): + if not spec.is_in(subval): + vals.append(spec.project(subval)) + else: + vals.append(subval) + res = torch.stack(vals, dim=self.dim) + if not isinstance(val, LazyStackedTensorDict): + res = res.to_tensordict() + return res def type_check( self, value: Union[torch.Tensor, TensorDictBase], - selected_keys: Union[str, Optional[Sequence[str]]] = None, + selected_keys: Union[NestedKey, Optional[Sequence[NestedKey]]] = None, ): - pass + if selected_keys is None: + if isinstance(value, torch.Tensor): + raise ValueError( + "value must be of type TensorDictBase when key is None" + ) + for spec, subvalue in zip(self._specs, value.unbind(self.dim)): + spec.type_check(subvalue) + else: + if isinstance(value, torch.Tensor) and isinstance(selected_keys, str): + value = {selected_keys: value} + selected_keys = [selected_keys] + for _key in self.keys(): + if self[_key] is not None and _key in selected_keys: + self[_key].type_check(value[_key], _key) def __repr__(self) -> str: sub_str = ",\n".join( [indent(f"{k}: {repr(item)}", 4 * " ") for k, item in self.items()] ) - device_str = f"device={self._specs[0].device}" - shape_str = f"shape={self.shape}" - sub_str = ", ".join([sub_str, device_str, shape_str]) - return ( - f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" + sub_str = indent(f"fields={{\n{', '.join([sub_str])}}}", 4 * " ") + 
exclusive_key_str = self.repr_exclusive_keys() + device_str = indent(f"device={self._specs[0].device}", 4 * " ") + shape_str = indent(f"shape={self.shape}", 4 * " ") + stack_dim = indent(f"stack_dim={self.dim}", 4 * " ") + string = ",\n".join( + [sub_str, exclusive_key_str, device_str, shape_str, stack_dim] + ) + return f"LazyStackedCompositeSpec(\n{string})" + + def repr_exclusive_keys(self): + keys = set(self.keys()) + exclusive_keys = [ + ",\n".join( + [ + indent(f"{k}: {repr(spec[k])}", 4 * " ") + for k in spec.keys() + if k not in keys + ] + ) + for spec in self._specs + ] + exclusive_key_str = ",\n".join( + [ + indent(f"{i} ->\n{line}", 4 * " ") + for i, line in enumerate(exclusive_keys) + if line != "" + ] ) - def encode( - self, vals: Dict[str, Any], ignore_device: bool = False - ) -> Dict[str, torch.Tensor]: - pass + return indent(f"exclusive_fields={{\n{exclusive_key_str}}}", 4 * " ") + + def is_in(self, val) -> bool: + for spec, subval in zip(self._specs, val.unbind(self.dim)): + if not spec.is_in(subval): + return False + return True - def __delitem__(self, key): - pass + def __delitem__(self, key: NestedKey): + """Deletes a key from the stacked composite spec. + + This method will be executed if the key is present in at least one of the stacked specs, + otherwise it will raise an error. + + Args: + key (NestedKey): the key to delete. + """ + at_least_one_deletion = False + for spec in self._specs: + try: + del spec[key] + at_least_one_deletion = True + except KeyError: + continue + if not at_least_one_deletion: + raise KeyError( + f"Key {key} must be present in at least one of the stacked specs" + ) + return self def __iter__(self): - pass + for k in self.keys(): + yield self[k] - def __setitem__(self, key, value): - pass + def __setitem__(self, key: NestedKey, value): + key = unravel_key(key) + is_key = isinstance(key, str) or ( + isinstance(key, tuple) and all(isinstance(_item, str) for _item in key) + ) + if is_key: + self.set(key, value) + else: + raise ValueError( + f"{self.__class__} expects str or tuple of str as key to set values " + ) @property def device(self) -> DEVICE_TYPING: @@ -3204,15 +3977,54 @@ def ndimension(self): return len(self.shape) def set(self, name, spec): - if spec is not None: - shape = spec.shape - if shape[: self.ndim] != self.shape: - raise ValueError( - "The shape of the spec and the CompositeSpec mismatch: the first " - f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." 
+ for sub_spec, sub_item in zip(self._specs, spec.unbind(self.dim)): + sub_spec[name] = sub_item + + @property + def shape(self): + shape = list(self._specs[0].shape) + dim = self.dim + if dim < 0: + dim = len(shape) + dim + 1 + shape.insert(dim, len(self._specs)) + return torch.Size(shape) + + def expand(self, *shape): + if len(shape) == 1 and not isinstance(shape[0], (int,)): + return self.expand(*shape[0]) + expand_shape = shape[: -len(self.shape)] + existing_shape = self.shape + shape_check = shape[-len(self.shape) :] + for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): + if size1 != size2 and size1 != 1: + raise RuntimeError( + f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" ) - self._specs[name] = spec + elif size1 != size2 and size1 == 1 and _i == self.dim: + # if we're expanding along the stack dim we just need to clone the existing spec + return torch.stack( + [self._specs[0].clone() for _ in range(size2)], self.dim + ).expand(*shape) + if _i != len(self.shape) - 1: + raise RuntimeError( + f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." + ) + # remove the stack dim from the expanded shape, which we know to match + unstack_shape = list(expand_shape) + [ + s for i, s in enumerate(shape_check) if i != self.dim + ] + return torch.stack( + [spec.expand(unstack_shape) for spec in self._specs], + self.dim + len(expand_shape), + ) + + def empty(self): + return torch.stack([spec.empty() for spec in self._specs], dim=self.stack_dim) + + def encode( + self, vals: Dict[str, Any], ignore_device: bool = False + ) -> Dict[str, torch.Tensor]: + raise NOT_IMPLEMENTED_ERROR # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: @@ -3228,14 +4040,19 @@ def _stack_specs(list_of_spec, dim, out=None): spec0 = list_of_spec[0] if isinstance(spec0, TensorSpec): device = spec0.device + all_equal = True for spec in list_of_spec[1:]: - if not isinstance(spec, TensorSpec): + if not isinstance(spec, spec0.__class__): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of specs in the list." 
) if device != spec.device: raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + if spec.dtype != spec0.dtype: + raise RuntimeError(f"Dtypes differ, got {spec0.dtype} and {spec.dtype}") + if spec.ndim != spec0.ndim: + raise RuntimeError(f"Ndims differ, got {spec0.ndim} and {spec.ndim}") all_equal = all_equal and spec == spec0 if all_equal: shape = list(spec0.shape) @@ -3269,6 +4086,8 @@ def _stack_composite_specs(list_of_spec, dim, out=None): ) if device != spec.device: raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + if spec.shape != spec0.shape: + raise RuntimeError(f"Shapes differ, got {spec.shape} and {spec0.shape}") all_equal = all_equal and spec == spec0 if all_equal: shape = list(spec0.shape) diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index 76a5f5d7f61..5db57ec6ba3 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -8,8 +8,21 @@ import numpy as np import torch + from torch import Tensor +from torchrl.data.tensor_specs import ( + BinaryDiscreteTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + LazyStackedCompositeSpec, + LazyStackedTensorSpec, + MultiDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, + OneHotDiscreteTensorSpec, + TensorSpec, +) + numpy_to_torch_dtype_dict = { np.dtype("bool"): torch.bool, np.dtype("uint8"): torch.uint8, @@ -35,6 +48,180 @@ INDEX_TYPING = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]] +ACTION_SPACE_MAP = { + OneHotDiscreteTensorSpec: "one_hot", + MultiOneHotDiscreteTensorSpec: "mult_one_hot", + BinaryDiscreteTensorSpec: "binary", + DiscreteTensorSpec: "categorical", + "one_hot": "one_hot", + "one-hot": "one_hot", + "mult_one_hot": "mult_one_hot", + "mult-one-hot": "mult_one_hot", + "multi_one_hot": "mult_one_hot", + "multi-one-hot": "mult_one_hot", + "binary": "binary", + "categorical": "categorical", + MultiDiscreteTensorSpec: "multi_categorical", + "multi_categorical": "multi_categorical", + "multi-categorical": "multi_categorical", + "multi_discrete": "multi_categorical", + "multi-discrete": "multi_categorical", +} + + +def consolidate_spec( + spec: CompositeSpec, + recurse_through_entries: bool = True, + recurse_through_stack: bool = True, +): + """Given a TensorSpec, removes exclusive keys by adding 0 shaped specs. + + Args: + spec (CompositeSpec): the spec to be consolidated. + recurse_through_entries (bool): if True, call the function recursively on all entries of the spec. + Default is True. + recurse_through_stack (bool): if True, if the provided spec is lazy, the function recursively + on all specs in its list. Default is True. 
+ + """ + spec = spec.clone() + + if not isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + return spec + + if isinstance(spec, LazyStackedCompositeSpec): + keys = set(spec.keys()) # shared keys + exclusive_keys_per_spec = [ + set() for _ in range(len(spec._specs)) + ] # list of exclusive keys per td + exclusive_keys_examples = ( + {} + ) # map of all exclusive keys to a list of their values + for spec_index in range(len(spec._specs)): # gather all exclusive keys + sub_spec = spec._specs[spec_index] + if recurse_through_stack: + sub_spec = consolidate_spec( + sub_spec, recurse_through_entries, recurse_through_stack + ) + spec._specs[spec_index] = sub_spec + for sub_spec_key in sub_spec.keys(): + if sub_spec_key not in keys: # exclusive key + exclusive_keys_per_spec[spec_index].add(sub_spec_key) + value = sub_spec[sub_spec_key] + if sub_spec_key in exclusive_keys_examples: + exclusive_keys_examples[sub_spec_key].append(value) + else: + exclusive_keys_examples.update({sub_spec_key: [value]}) + + for sub_spec, exclusive_keys in zip( + spec._specs, exclusive_keys_per_spec + ): # add missing exclusive entries + for exclusive_key in set(exclusive_keys_examples.keys()).difference( + exclusive_keys + ): + exclusive_keys_example_list = exclusive_keys_examples[exclusive_key] + sub_spec.set( + exclusive_key, + _empty_like_spec(exclusive_keys_example_list, sub_spec.shape), + ) + + if recurse_through_entries: + for key, value in spec.items(): + if isinstance(value, (CompositeSpec, LazyStackedCompositeSpec)): + spec.set( + key, + consolidate_spec( + value, recurse_through_entries, recurse_through_stack + ), + ) + return spec + + +def _empty_like_spec(specs: List[TensorSpec], shape): + for spec in specs[1:]: + if spec.__class__ != specs[0].__class__: + raise ValueError( + "Found same key in lazy specs corresponding to entries with different classes" + ) + spec = specs[0] + if isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + # the exclusive key has values which are CompositeSpecs -> + # we create an empty composite spec with same batch size + return spec.empty() + elif isinstance(spec, LazyStackedTensorSpec): + # the exclusive key has values which are LazyStackedTensorSpecs -> + # we create a LazyStackedTensorSpec with the same shape (aka same -1s) as the first in the list. + # this will not add any new -1s when they are stacked + shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :]) + return LazyStackedTensorSpec( + *[_empty_like_spec(spec._specs, shape) for _ in spec._specs], + dim=spec.stack_dim, + ) + else: + # the exclusive key has values which are TensorSpecs -> + # if the shapes of the values are all the same, we create a TensorSpec with leading shape `shape` and following dims 0 (having the same ndims as the values) + # if the shapes of the values differ, we create a TensorSpec with 0 size in the differing dims + spec_shape = list(spec.shape) + + for dim_index in range(len(spec_shape)): + hetero_dim = False + for sub_spec in specs: + if sub_spec.shape[dim_index] != spec.shape[dim_index]: + hetero_dim = True + break + if hetero_dim: + spec_shape[dim_index] = 0 + + if 0 not in spec_shape: # the values have all same shape + spec_shape = [ + dim if i < len(shape) else 0 for i, dim in enumerate(spec_shape) + ] + + spec = spec[(0,) * len(spec.shape)] + spec = spec.expand(spec_shape) + + return spec + + +def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True): + """Given a TensorSpec, returns true if there are no exclusive keys. 
+ + Args: + spec (TensorSpec): the spec to check + recurse (bool): if True, check recursively in nested specs. Default is True. + """ + if isinstance(spec, LazyStackedCompositeSpec): + keys = set(spec.keys()) + for inner_td in spec._specs: + if recurse and not check_no_exclusive_keys(inner_td): + return False + if set(inner_td.keys()) != keys: + return False + elif isinstance(spec, CompositeSpec) and recurse: + for value in spec.values(): + if not check_no_exclusive_keys(value): + return False + else: + return True + return True + + +def contains_lazy_spec(spec: TensorSpec) -> bool: + """Returns true if a spec contains lazy stacked specs. + + Args: + spec (TensorSpec): the spec to check + + """ + if isinstance(spec, (LazyStackedTensorSpec, LazyStackedCompositeSpec)): + return True + elif isinstance(spec, CompositeSpec): + for inner_spec in spec.values(): + if contains_lazy_spec(inner_spec): + return True + return False + + class CloudpickleWrapper(object): """A wrapper for functions that allow for serialization in multiprocessed settings.""" @@ -58,6 +245,81 @@ def __setstate__(self, ob: bytes): self.fn, self.kwargs = pickle.loads(ob) def __call__(self, *args, **kwargs) -> Any: - kwargs = {k: item for k, item in kwargs.items()} kwargs.update(self.kwargs) - return self.fn(**kwargs) + return self.fn(*args, **kwargs) + + +def _process_action_space_spec(action_space, spec): + original_spec = spec + composite_spec = False + if isinstance(spec, CompositeSpec): + # this will break whenever our action is more complex than a single tensor + try: + if "action" in spec.keys(): + _key = "action" + else: + # the first key is the action + for _key in spec.keys(True, True): + if isinstance(_key, tuple) and _key[-1] == "action": + break + else: + raise KeyError + spec = spec[_key] + composite_spec = True + except KeyError: + raise KeyError( + "action could not be found in the spec. Make sure " + "you pass a spec that is either a native action spec or a composite action spec " + "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." + ) + if action_space is not None: + if isinstance(action_space, CompositeSpec): + raise ValueError("action_space cannot be of type CompositeSpec.") + if ( + spec is not None + and isinstance(action_space, TensorSpec) + and action_space is not spec + ): + raise ValueError( + "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." + ) + if isinstance(action_space, TensorSpec): + spec = action_space + action_space = _find_action_space(action_space) + # check that the spec and action_space match + if spec is not None and _find_action_space(spec) != action_space: + raise ValueError( + f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." + ) + elif spec is not None: + action_space = _find_action_space(spec) + else: + raise ValueError( + "Neither action_space nor spec was defined. The action space cannot be inferred." 
+ ) + if composite_spec: + spec = original_spec + return action_space, spec + + +def _find_action_space(action_space): + if isinstance(action_space, TensorSpec): + if isinstance(action_space, CompositeSpec): + if "action" in action_space.keys(): + _key = "action" + else: + # the first key is the action + for _key in action_space.keys(True, True): + if isinstance(_key, tuple) and _key[-1] == "action": + break + else: + raise KeyError + action_space = action_space[_key] + action_space = type(action_space) + try: + action_space = ACTION_SPACE_MAP[action_space] + except KeyError: + raise ValueError( + f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}." + ) + return action_space diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 827479bf9c0..461a47aa7da 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -3,18 +3,49 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict from .env_creator import EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv +from .libs import ( + BraxEnv, + BraxWrapper, + DMControlEnv, + DMControlWrapper, + gym_backend, + GymEnv, + GymWrapper, + HabitatEnv, + IsaacGymEnv, + IsaacGymWrapper, + JumanjiEnv, + JumanjiWrapper, + MultiThreadedEnv, + MultiThreadedEnvWrapper, + OpenMLEnv, + PettingZooEnv, + PettingZooWrapper, + RoboHiveEnv, + set_gym_backend, + SMACv2Env, + SMACv2Wrapper, + VmasEnv, + VmasWrapper, +) from .model_based import ModelBasedEnvBase from .transforms import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, CenterCrop, + ClipTransform, Compose, + DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, + DTypeCastTransform, + EndOfLifeTransform, ExcludeTransform, FiniteTensorDictCheck, FlattenObservation, @@ -26,6 +57,7 @@ NoopResetEnv, ObservationNorm, ObservationTransform, + PermuteTransform, PinMemoryTransform, R3MTransform, RandomCropTensorDict, @@ -52,12 +84,13 @@ ) from .utils import ( check_env_specs, + check_marl_grouping, exploration_mode, exploration_type, ExplorationType, make_composite_from_td, + MarlGroupMapType, set_exploration_mode, set_exploration_type, step_mdp, ) -from .vec_env import MultiThreadedEnv, ParallelEnv, SerialEnv diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/batched_envs.py similarity index 60% rename from torchrl/envs/vec_env.py rename to torchrl/envs/batched_envs.py index 108066522e1..aa1f256c070 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/batched_envs.py @@ -5,38 +5,35 @@ from __future__ import annotations -import importlib -import logging import os from collections import OrderedDict from copy import deepcopy from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from warnings import warn -import numpy as np import torch -from tensordict import TensorDict, unravel_key -from tensordict._tensordict import _unravel_key_to_tuple +from tensordict import TensorDict +from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp -from 
torchrl._utils import _check_for_faulty_process, VERBOSE -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import _EnvWrapper, EnvBase +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE +from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING +from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import get_env_metadata -from torchrl.envs.utils import _set_single_key, _sort_keys +from torchrl.envs.utils import ( + _aggregate_resets, + _set_single_key, + _sort_keys, + clear_mpi_env_vars, +) -_has_envpool = importlib.util.find_spec("envpool") +# legacy +from .libs.envpool import MultiThreadedEnv, MultiThreadedEnvWrapper # noqa: F401 def _check_start(fun): @@ -117,7 +114,6 @@ class _BatchedEnv(EnvBase): if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were needed, which comes with a slight compute overhead; create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created; - pin_memory (bool): if True and device is "cpu", calls :obj:`pin_memory` on the tensordicts when created. share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy stack is returned. default = None (False if single task); @@ -130,9 +126,15 @@ class _BatchedEnv(EnvBase): It is assumed that all environments will run on the same device as a common shared tensordict will be used to pass data from process to process. The device can be changed after instantiation using :obj:`env.to(device)`. - allow_step_when_done (bool, optional): if ``True``, batched environments can - execute steps after a done state is encountered. - Defaults to ``False``. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + This parameter has no effect for the :class:`~SerialEnv` class. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may charge the cpu load too much and harm performance. + This parameter has no effect for the :class:`~SerialEnv` class. 
""" @@ -156,6 +158,8 @@ def __init__( policy_proof: Optional[Callable] = None, device: Optional[DEVICE_TYPING] = None, allow_step_when_done: bool = False, + num_threads: int = None, + num_sub_threads: int = 1, ): if device is not None: raise ValueError( @@ -166,6 +170,10 @@ def __init__( super().__init__(device=None) self.is_closed = True + if num_threads is None: + num_threads = num_workers + 1 # 1 more thread for this proc + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads self._cache_in_keys = None self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1) @@ -195,10 +203,15 @@ def __init__( self.create_env_fn = create_env_fn self.create_env_kwargs = create_env_kwargs self.pin_memory = pin_memory + if pin_memory: + raise ValueError("pin_memory for batched envs is deprecated") + self.share_individual_td = bool(share_individual_td) self._share_memory = shared_memory self._memmap = memmap self.allow_step_when_done = allow_step_when_done + if allow_step_when_done: + raise ValueError("allow_step_when_done is deprecated") if self._share_memory and self._memmap: raise RuntimeError( "memmap and shared memory are mutually exclusive features." @@ -264,11 +277,11 @@ def _set_properties(self): input_spec = meta_data.specs["input_spec"].to(device) output_spec = meta_data.specs["output_spec"].to(device) - self.action_spec = input_spec["_action_spec"] - self.state_spec = input_spec["_state_spec"] - self.observation_spec = output_spec["_observation_spec"] - self.reward_spec = output_spec["_reward_spec"] - self.done_spec = output_spec["_done_spec"] + self.action_spec = input_spec["full_action_spec"] + self.state_spec = input_spec["full_state_spec"] + self.observation_spec = output_spec["full_observation_spec"] + self.reward_spec = output_spec["full_reward_spec"] + self.done_spec = output_spec["full_done_spec"] self._dummy_env_str = meta_data.env_str self._env_tensordict = meta_data.tensordict @@ -287,18 +300,19 @@ def _set_properties(self): output_spec.append(md.specs["output_spec"]) output_spec = torch.stack(output_spec, 0) - self.action_spec = input_spec["_action_spec"] - self.state_spec = input_spec["_state_spec"] + self.action_spec = input_spec["full_action_spec"] + self.state_spec = input_spec["full_state_spec"] - self.observation_spec = output_spec["_observation_spec"] - self.reward_spec = output_spec["_reward_spec"] - self.done_spec = output_spec["_done_spec"] + self.observation_spec = output_spec["full_observation_spec"] + self.reward_spec = output_spec["full_reward_spec"] + self.done_spec = output_spec["full_done_spec"] self._dummy_env_str = str(meta_data[0]) self._env_tensordict = torch.stack( [meta_data.tensordict for meta_data in meta_data], 0 ) self._batch_locked = meta_data[0].batch_locked + self.has_lazy_inputs = contains_lazy_spec(self.input_spec) def state_dict(self) -> OrderedDict: raise NotImplementedError @@ -323,74 +337,88 @@ def _create_td(self) -> None: shared_tensordict_parent = self._env_tensordict.clone() if self._single_task: - self.env_input_keys = sorted( - list(self.input_spec["_action_spec"].keys(True, True)) + self._env_input_keys = sorted( + list(self.input_spec["full_action_spec"].keys(True, True)) + list(self.state_spec.keys(True, True)), key=_sort_keys, ) - self.env_output_keys = [] - self.env_obs_keys = [] - for key in self.output_spec["_observation_spec"].keys(True, True): - self.env_output_keys.append(unravel_key(("next", key))) - self.env_obs_keys.append(key) - self.env_output_keys.append(unravel_key(("next", 
self.reward_key))) - self.env_output_keys.append(unravel_key(("next", self.done_key))) + self._env_output_keys = [] + self._env_obs_keys = [] + for key in self.output_spec["full_observation_spec"].keys(True, True): + self._env_output_keys.append(key) + self._env_obs_keys.append(key) + self._env_output_keys += self.reward_keys + self.done_keys else: env_input_keys = set() for meta_data in self.meta_data: - if meta_data.specs["input_spec", "_state_spec"] is not None: + if meta_data.specs["input_spec", "full_state_spec"] is not None: env_input_keys = env_input_keys.union( - meta_data.specs["input_spec", "_state_spec"].keys(True, True) + meta_data.specs["input_spec", "full_state_spec"].keys( + True, True + ) ) env_input_keys = env_input_keys.union( - meta_data.specs["input_spec", "_action_spec"].keys(True, True) + meta_data.specs["input_spec", "full_action_spec"].keys(True, True) ) env_output_keys = set() env_obs_keys = set() for meta_data in self.meta_data: env_obs_keys = env_obs_keys.union( key - for key in meta_data.specs["output_spec"]["_observation_spec"].keys( - True, True - ) + for key in meta_data.specs["output_spec"][ + "full_observation_spec" + ].keys(True, True) ) env_output_keys = env_output_keys.union( - unravel_key(("next", key)) - for key in meta_data.specs["output_spec"]["_observation_spec"].keys( + meta_data.specs["output_spec"]["full_observation_spec"].keys( True, True ) ) - env_output_keys = env_output_keys.union( - { - unravel_key(("next", self.reward_key)), - unravel_key(("next", self.done_key)), - } - ) - self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys) - self.env_input_keys = sorted(env_input_keys, key=_sort_keys) - self.env_output_keys = sorted(env_output_keys, key=_sort_keys) + env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) + self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) + self._env_input_keys = sorted(env_input_keys, key=_sort_keys) + self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + reset_keys = self.reset_keys self._selected_keys = ( - set(self.env_output_keys) - .union(self.env_input_keys) - .union(self.env_obs_keys) + set(self._env_output_keys) + .union(self._env_input_keys) + .union(self._env_obs_keys) + .union(set(self.done_keys)) ) - self._selected_keys.add(self.done_key) - self._selected_keys.add("_reset") + self._selected_keys = self._selected_keys.union(reset_keys) - self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"] - self._selected_step_keys = self.env_output_keys + # input keys + self._selected_input_keys = { + _unravel_key_to_tuple(key) for key in self._env_input_keys + } + # output keys after reset + self._selected_reset_keys = { + _unravel_key_to_tuple(key) + for key in self._env_obs_keys + self.done_keys + reset_keys + } + # output keys after reset, filtered + self._selected_reset_keys_filt = { + unravel_keys(key) for key in self._env_obs_keys + self.done_keys + } + # output keys after step + self._selected_step_keys = { + _unravel_key_to_tuple(key) for key in self._env_output_keys + } if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, + "next", strict=False, ) self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ - tensordict.select(*self._selected_keys, strict=False).to(self.device) + tensordict.select(*self._selected_keys, "next", strict=False).to( + self.device + ) for tensordict in 
shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -409,12 +437,13 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() - elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + if self.device.type == "cpu": + if self._share_memory: + for td in self.shared_tensordicts: + td.share_memory_() + elif self._memmap: + for td in self.shared_tensordicts: + td.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -424,10 +453,7 @@ def _create_td(self) -> None: self.shared_tensordict_parent.memmap_() if not self.shared_tensordict_parent.is_memmap(): raise RuntimeError("memmap_() failed") - self.shared_tensordicts = self.shared_tensordict_parent.unbind(0) - if self.pin_memory: - self.shared_tensordict_parent.pin_memory() def _start_workers(self) -> None: """Starts the various envs.""" @@ -536,27 +562,24 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - self._assert_tensordict_shape(tensordict) tensordict_in = tensordict.clone(False) + next_td = self.shared_tensordict_parent.get("next") for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. out_td = self._envs[i]._step(tensordict_in[i]) - out_td.update(tensordict_in[i].select(*self.env_input_keys)) - self.shared_tensordicts[i].update_( - out_td.select(*self.env_input_keys, *self.env_output_keys) - ) + next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps if self._single_task: - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_step_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there - out = self.shared_tensordict_parent.select( - *self._selected_step_keys, strict=False - ).clone() + out = next_td.select(*self._selected_step_keys, strict=False).clone() return out def _shutdown_workers(self) -> None: @@ -576,15 +599,19 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - if tensordict is not None and "_reset" in tensordict.keys(): - self._assert_tensordict_shape(tensordict) - _reset = tensordict.get("_reset") - if _reset.shape[-len(self.done_spec.shape) :] != self.done_spec.shape: - raise RuntimeError( - "_reset flag in tensordict should follow env.done_spec" - ) + + if tensordict is not None: + needs_resetting = _aggregate_resets(tensordict, reset_keys=self.reset_keys) + if needs_resetting.ndim > 2: + needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1) + if needs_resetting.ndim > 1: + needs_resetting = needs_resetting.any(-1) + elif not needs_resetting.ndim: + needs_resetting = needs_resetting.expand((self.num_workers,)) else: - _reset = torch.ones(self.done_spec.shape, dtype=torch.bool) + needs_resetting = torch.ones( + (self.num_workers,), device=self.device, dtype=torch.bool + ) 
for i, _env in enumerate(self._envs): if tensordict is not None: @@ -593,7 +620,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None else: tensordict_ = None - if not _reset[i].any(): + + if not needs_resetting[i]: # We update the stored tensordict with the value of the "next" # key as one may be surprised to receive data that is not up-to-date # If we don't do this, the result of calling reset and skipping one env @@ -601,9 +629,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). self.shared_tensordicts[i].update_( - self.shared_tensordicts[i]["next"].select( - *self._selected_reset_keys, strict=False - ) + self.shared_tensordicts[i] + .get("next") + .select(*self._selected_reset_keys, strict=False) ) if tensordict_ is not None: self.shared_tensordicts[i].update_( @@ -612,19 +640,20 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: continue _td = _env._reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( - _td.select(*self._selected_keys, strict=False) + _td.select(*self._selected_reset_keys, strict=False) ) - + selected_output_keys = self._selected_reset_keys_filt if self._single_task: # select + clone creates 2 tds, but we can create one only - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) - for key in self._selected_reset_keys: - if key != "_reset": - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) + for key in selected_output_keys: + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: return self.shared_tensordict_parent.select( - *[key for key in self._selected_reset_keys if key != "_reset"], + *selected_output_keys, strict=False, ).clone() @@ -682,52 +711,63 @@ class ParallelEnv(_BatchedEnv): __doc__ += _BatchedEnv.__doc__ def _start_workers(self) -> None: - _num_workers = self.num_workers + from torchrl.envs.env_creator import EnvCreator + + torch.set_num_threads(self.num_threads) + ctx = mp.get_context("spawn") + _num_workers = self.num_workers + self.parent_channels = [] self._workers = [] + self._events = [] if self.device.type == "cuda": self.event = torch.cuda.Event() else: self.event = None - for idx in range(_num_workers): - if self._verbose: - print(f"initiating worker {idx}") - # No certainty which module multiprocessing_context is - channel1, channel2 = ctx.Pipe() - env_fun = self.create_env_fn[idx] - if env_fun.__class__.__name__ != "EnvCreator": - env_fun = CloudpickleWrapper(env_fun) - - w = mp.Process( - target=_run_worker_pipe_shared_mem, - args=( - idx, - channel1, - channel2, - env_fun, - self.create_env_kwargs[idx], - False, - self.env_input_keys, - self.device, - self.allow_step_when_done, - ), - ) - w.daemon = True - w.start() - channel2.close() - self.parent_channels.append(channel1) - self._workers.append(w) - for channel1 in self.parent_channels: - msg = channel1.recv() + with clear_mpi_env_vars(): + for idx in range(_num_workers): + if self._verbose: + print(f"initiating worker {idx}") + # No certainty which module multiprocessing_context is + parent_pipe, child_pipe = ctx.Pipe() + event = ctx.Event() + self._events.append(event) + env_fun = self.create_env_fn[idx] + if not isinstance(env_fun, EnvCreator): + env_fun = CloudpickleWrapper(env_fun) + + process = _ProcessNoWarn( + 
target=_run_worker_pipe_shared_mem, + num_threads=self.num_sub_threads, + args=( + parent_pipe, + child_pipe, + env_fun, + self.create_env_kwargs[idx], + self.device, + event, + self.shared_tensordicts[idx], + self._selected_input_keys, + self._selected_reset_keys, + self._selected_step_keys, + self.has_lazy_inputs, + ), + ) + process.daemon = True + process.start() + child_pipe.close() + self.parent_channels.append(parent_pipe) + self._workers.append(process) + + for parent_pipe in self.parent_channels: + msg = parent_pipe.recv() assert msg == "started" # send shared tensordict to workers - for channel, shared_tensordict in zip( - self.parent_channels, self.shared_tensordicts - ): - channel.send(("init", shared_tensordict)) + for channel in self.parent_channels: + channel.send(("init", None)) self.is_closed = False @_check_start @@ -751,22 +791,15 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: ) for i, channel in enumerate(self.parent_channels): channel.send(("load_state_dict", state_dict[f"worker{i}"])) - for channel in self.parent_channels: - msg, _ = channel.recv() - if msg != "loaded": - raise RuntimeError(f"Expected 'loaded' but received {msg}") + for event in self._events: + event.wait() + event.clear() @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - self._assert_tensordict_shape(tensordict) - if self._single_task: + if self._single_task and not self.has_lazy_inputs: # this is faster than update_ but won't work for lazy stacks - for key in self.env_input_keys: - # self.shared_tensordict_parent.set( - # key, - # tensordict.get(key), - # inplace=True, - # ) + for key in self._env_input_keys: key = _unravel_key_to_tuple(key) self.shared_tensordict_parent._set_tuple( key, @@ -776,7 +809,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) else: self.shared_tensordict_parent.update_( - tensordict.select(*self.env_input_keys, strict=False) + tensordict.select(*self._env_input_keys, strict=False) ) if self.event is not None: self.event.record() @@ -784,43 +817,42 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) - # keys = set() for i in range(self.num_workers): - msg, data = self.parent_channels[i].recv() - if msg != "step_result": - raise RuntimeError( - f"Expected 'step_result' but received {msg} from worker {i}" - ) - if data is not None: - self.shared_tensordicts[i].update_(data) + event = self._events[i] + event.wait() + event.clear() + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + next_td = self.shared_tensordict_parent.get("next") if self._single_task: - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_step_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there - out = self.shared_tensordict_parent.select( - *self._selected_step_keys, strict=False - ).clone() + out = next_td.select(*self._selected_step_keys, strict=False).clone() return out @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - cmd_out = "reset" - if tensordict is not None and "_reset" in tensordict.keys(): - self._assert_tensordict_shape(tensordict) - _reset = 
tensordict.get("_reset") - if _reset.shape[-len(self.done_spec.shape) :] != self.done_spec.shape: - raise RuntimeError( - "_reset flag in tensordict should follow env.done_spec" - ) + if tensordict is not None: + needs_resetting = _aggregate_resets(tensordict, reset_keys=self.reset_keys) + if needs_resetting.ndim > 2: + needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1) + if needs_resetting.ndim > 1: + needs_resetting = needs_resetting.any(-1) + elif not needs_resetting.ndim: + needs_resetting = needs_resetting.expand((self.num_workers,)) else: - _reset = torch.ones( - self.done_spec.shape, dtype=torch.bool, device=self.device + needs_resetting = torch.ones( + (self.num_workers,), device=self.device, dtype=torch.bool ) + workers = [] + for i, channel in enumerate(self.parent_channels): if tensordict is not None: tensordict_ = tensordict[i] @@ -828,7 +860,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None else: tensordict_ = None - if not _reset[i].any(): + if not needs_resetting[i]: # We update the stored tensordict with the value of the "next" # key as one may be surprised to receive data that is not up-to-date # If we don't do this, the result of calling reset and skipping one env @@ -836,36 +868,36 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). self.shared_tensordicts[i].update_( - self.shared_tensordicts[i]["next"].select( - *self._selected_reset_keys, strict=False - ) + self.shared_tensordicts[i] + .get("next") + .select(*self._selected_reset_keys, strict=False) ) if tensordict_ is not None: self.shared_tensordicts[i].update_( tensordict_.select(*self._selected_reset_keys, strict=False) ) continue - out = (cmd_out, tensordict_) + out = ("reset", tensordict_) channel.send(out) + workers.append(i) - for i, channel in enumerate(self.parent_channels): - if not _reset[i].any(): - continue - cmd_in, data = channel.recv() - if cmd_in != "reset_obs": - raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs") - if data is not None: - self.shared_tensordicts[i].update_(data) + for i in workers: + event = self._events[i] + event.wait() + event.clear() + + selected_output_keys = self._selected_reset_keys_filt if self._single_task: # select + clone creates 2 tds, but we can create one only - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) - for key in self._selected_reset_keys: - if key != "_reset": - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) + for key in selected_output_keys: + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: return self.shared_tensordict_parent.select( - *[key for key in self._selected_reset_keys if key != "_reset"], + *selected_output_keys, strict=False, ).clone() @@ -878,15 +910,9 @@ def _shutdown_workers(self) -> None: for i, channel in enumerate(self.parent_channels): if self._verbose: print(f"closing {i}") - # try: channel.send(("close", None)) - # except: - # raise RuntimeError(f"closing {channel} number {i} failed") - msg, _ = channel.recv() - if msg != "closing": - raise RuntimeError( - f"Expected 'closing' but received {msg} from worker {i}" - ) + self._events[i].wait() + self._events[i].clear() del self.shared_tensordicts, self.shared_tensordict_parent @@ -974,15 +1000,17 @@ def 
_recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered def _run_worker_pipe_shared_mem( - idx: int, parent_pipe: connection.Connection, child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - pin_memory: bool, - env_input_keys: Dict[str, Any], device: DEVICE_TYPING = None, - allow_step_when_done: bool = False, + mp_event: mp.Event = None, + shared_tensordict: TensorDictBase = None, + _selected_input_keys=None, + _selected_reset_keys=None, + _selected_step_keys=None, + has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: if device is None: @@ -1003,14 +1031,11 @@ def _run_worker_pipe_shared_mem( ) env = env_fun env = env.to(device) + del env_fun i = -1 initialized = False - # make sure that process can be closed - shared_tensordict = None - local_tensordict = None - child_pipe.send("started") while True: @@ -1032,8 +1057,10 @@ def _run_worker_pipe_shared_mem( if initialized: raise RuntimeError("worker already initialized") i = 0 - shared_tensordict = data next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" @@ -1045,55 +1072,31 @@ def _run_worker_pipe_shared_mem( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - local_tensordict = data - local_tensordict = env._reset(tensordict=local_tensordict) - - if "_reset" in local_tensordict.keys(): - local_tensordict.del_("_reset") - if pin_memory: - local_tensordict.pin_memory() - shared_tensordict.update_(local_tensordict) + cur_td = env._reset(tensordict=data) + shared_tensordict.update_(cur_td) if event is not None: event.record() event.synchronize() - out = ("reset_obs", None) - child_pipe.send(out) + mp_event.set() elif cmd == "step": if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if local_tensordict is not None: - for key in env_input_keys: - # local_tensordict.set(key, shared_tensordict.get(key)) - key = _unravel_key_to_tuple(key) - local_tensordict._set_tuple( - key, - shared_tensordict._get_tuple(key, None), - inplace=False, - validated=True, - ) - else: - local_tensordict = shared_tensordict.clone(recurse=False) - local_tensordict = env._step(local_tensordict) - if pin_memory: - local_tensordict.pin_memory() - msg = "step_result" - next_shared_tensordict.update_(local_tensordict.get("next")) + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) if event is not None: event.record() event.synchronize() - out = (msg, None) - child_pipe.send(out) + mp_event.set() elif cmd == "close": - del shared_tensordict, local_tensordict, data + del shared_tensordict, data if not initialized: raise RuntimeError("call 'init' before closing") env.close() del env - - child_pipe.send(("closing", None)) + mp_event.set() child_pipe.close() if verbose: print(f"{pid} closed") @@ -1101,8 +1104,7 @@ def _run_worker_pipe_shared_mem( elif cmd == "load_state_dict": env.load_state_dict(data) - msg = "loaded" - child_pipe.send((msg, None)) + mp_event.set() elif cmd == "state_dict": state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) @@ -1133,299 +1135,3 @@ def _run_worker_pipe_shared_mem( else: # don't send env through pipe child_pipe.send(("_".join([cmd, "done"]), None)) - - -class 
MultiThreadedEnvWrapper(_EnvWrapper): - """Wrapper for envpool-based multithreaded environments.""" - - _verbose: bool = False - - def __init__( - self, - env: Optional["envpool.python.envpool.EnvPoolMixin"] = None, # noqa: F821 - **kwargs, - ): - if not _has_envpool: - raise ImportError( - "envpool python package or one of its dependencies (gym, treevalue) were not found. Please install these dependencies." - ) - if env is not None: - kwargs["env"] = env - self.num_workers = env.config["num_envs"] - # For synchronous mode batch size is equal to the number of workers - self.batch_size = torch.Size([self.num_workers]) - super().__init__(**kwargs) - - # Buffer to keep the latest observation for each worker - # It's a TensorDict when the observation consists of several variables, e.g. "position" and "velocity" - self.obs: Union[torch.tensor, TensorDict] = self.observation_spec.zero() - - def _check_kwargs(self, kwargs: Dict): - if "env" not in kwargs: - raise TypeError("Could not find environment key 'env' in kwargs.") - env = kwargs["env"] - import envpool - - if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)): - raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.") - - def _build_env(self, env: "envpool.python.envpool.EnvPoolMixin"): # noqa: F821 - return env - - def _make_specs( - self, env: "envpool.python.envpool.EnvPoolMixin" # noqa: F821 - ) -> None: # noqa: F821 - from torchrl.envs.libs.gym import set_gym_backend - - with set_gym_backend("gym"): - self.action_spec = self._get_action_spec() - output_spec = self._get_output_spec() - self.observation_spec = output_spec["_observation_spec"] - self.reward_spec = output_spec["_reward_spec"] - self.done_spec = output_spec["_done_spec"] - - def _init_env(self) -> Optional[int]: - pass - - def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: - if tensordict is not None: - reset_workers = tensordict.get("_reset", None) - else: - reset_workers = None - if reset_workers is not None: - reset_data = self._env.reset(np.where(reset_workers.cpu().numpy())[0]) - else: - reset_data = self._env.reset() - tensordict_out = self._transform_reset_output(reset_data, reset_workers) - self.is_closed = False - return tensordict_out - - @torch.no_grad() - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - action = tensordict.get(self.action_key) - # Action needs to be moved to CPU and converted to numpy before being passed to envpool - action = action.to(torch.device("cpu")) - step_output = self._env.step(action.numpy()) - tensordict_out = self._transform_step_output(step_output) - return tensordict_out.select().set("next", tensordict_out) - - def _get_action_spec(self) -> TensorSpec: - # local import to avoid importing gym in the script - from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform - - # Envpool provides Gym-compatible specs as env.spec.action_space and - # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones. 
- - # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase - action_spec = _gym_to_torchrl_spec_transform( - self._env.spec.action_space, - device=self.device, - categorical_action_encoding=True, - ) - action_spec = self._add_shape_to_spec(action_spec) - return action_spec - - def _get_output_spec(self) -> TensorSpec: - return CompositeSpec( - _observation_spec=self._get_observation_spec(), - _reward_spec=self._get_reward_spec(), - _done_spec=self._get_done_spec(), - shape=(self.num_workers,), - device=self.device, - ) - - def _get_observation_spec(self) -> TensorSpec: - # local import to avoid importing gym in the script - from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform - - # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase - observation_spec = _gym_to_torchrl_spec_transform( - self._env.spec.observation_space, - device=self.device, - categorical_action_encoding=True, - ) - observation_spec = self._add_shape_to_spec(observation_spec) - if isinstance(observation_spec, CompositeSpec): - return observation_spec - return CompositeSpec( - observation=observation_spec, - shape=(self.num_workers,), - device=self.device, - ) - - def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec: - return spec.expand((self.num_workers, *spec.shape)) - - def _get_reward_spec(self) -> TensorSpec: - return UnboundedContinuousTensorSpec( - device=self.device, - shape=self.batch_size, - ) - - def _get_done_spec(self) -> TensorSpec: - return DiscreteTensorSpec( - 2, - device=self.device, - shape=self.batch_size, - dtype=torch.bool, - ) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(num_workers={self.num_workers}, device={self.device})" - - def _transform_reset_output( - self, - envpool_output: Tuple[ - Union["treevalue.TreeValue", np.ndarray], Any # noqa: F821 - ], - reset_workers: Optional[torch.Tensor], - ): - """Process output of envpool env.reset.""" - import treevalue - - observation, _ = envpool_output - if reset_workers is not None: - # Only specified workers were reset - need to set observation buffer values only for them - if isinstance(observation, treevalue.TreeValue): - # If observation contain several fields, it will be returned as treevalue.TreeValue. - # Convert to treevalue.FastTreeValue to allow indexing - observation = treevalue.FastTreeValue(observation) - self.obs[reset_workers] = self._treevalue_or_numpy_to_tensor_or_dict( - observation - ) - else: - # All workers were reset - rewrite the whole observation buffer - self.obs = TensorDict( - self._treevalue_or_numpy_to_tensor_or_dict(observation), self.batch_size - ) - - obs = self.obs.clone(False) - obs.update({self.done_key: self.done_spec.zero()}) - return obs - - def _transform_step_output( - self, envpool_output: Tuple[Any, Any, Any, ...] - ) -> TensorDict: - """Process output of envpool env.step.""" - obs, reward, done, *_ = envpool_output - - obs = self._treevalue_or_numpy_to_tensor_or_dict(obs) - obs.update({self.reward_key: torch.tensor(reward), self.done_key: done}) - self.obs = tensordict_out = TensorDict( - obs, - batch_size=self.batch_size, - device=self.device, - ) - return tensordict_out - - def _treevalue_or_numpy_to_tensor_or_dict( - self, x: Union["treevalue.TreeValue", np.ndarray] # noqa: F821 - ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - """Converts observation returned by EnvPool. 
- - EnvPool step and reset return observation as a numpy array or a TreeValue of numpy arrays, which we convert - to a tensor or a dictionary of tensors. Currently only supports depth 1 trees, but can easily be extended to - arbitrary depth if necessary. - """ - import treevalue - - if isinstance(x, treevalue.TreeValue): - ret = self._treevalue_to_dict(x) - elif not isinstance(x, dict): - ret = {"observation": torch.tensor(x)} - else: - ret = x - return ret - - def _treevalue_to_dict( - self, tv: "treevalue.TreeValue" # noqa: F821 - ) -> Dict[str, Any]: - """Converts TreeValue to a dictionary. - - Currently only supports depth 1 trees, but can easily be extended to arbitrary depth if necessary. - """ - import treevalue - - return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)} - - def _set_seed(self, seed: Optional[int]): - if seed is not None: - print( - "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an existing envorinment is not\ - supported by envpool. Please create a new environment, passing the seed to the constructor." - ) - - -class MultiThreadedEnv(MultiThreadedEnvWrapper): - """Multithreaded execution of environments based on EnvPool. - - An alternative to ParallelEnv based on multithreading. It's faster, as it doesn't require new process spawning, but - less flexible, as it only supports environments implemented in EnvPool library. - Currently only supports synchronous execution mode, when the batch size is equal to the number of workers, see - https://envpool.readthedocs.io/en/latest/content/python_interface.html#batch-size. - - >>> env = MultiThreadedEnv(num_workers=3, env_name="Pendulum-v1") - >>> env.reset() - >>> env.rand_step() - >>> env.rollout(5) - >>> env.close() - - Args: - num_workers: number of worker threads to create. - env_name: name of the environment, corresponding to task_id in EnvPool. - create_env_kwargs: additional arguments which will be passed to envpool.make. 
- """ - - def __init__( - self, - num_workers: int, - env_name: str, - create_env_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ): - self.env_name = env_name.replace("ALE/", "") # Naming convention of EnvPool - self.num_workers = num_workers - self.batch_size = torch.Size([num_workers]) - self.create_env_kwargs = create_env_kwargs or {} - - kwargs["num_workers"] = num_workers - kwargs["env_name"] = self.env_name - kwargs["create_env_kwargs"] = create_env_kwargs - super().__init__(**kwargs) - - def _build_env( - self, - env_name: str, - num_workers: int, - create_env_kwargs: Optional[Dict[str, Any]], - ) -> Any: - import envpool - - create_env_kwargs = create_env_kwargs or {} - env = envpool.make( - task_id=env_name, - env_type="gym", - num_envs=num_workers, - gym_reset_return_info=True, - **create_env_kwargs, - ) - return super()._build_env(env) - - def _set_seed(self, seed: Optional[int]): - """Library EnvPool only supports setting a seed by recreating the environment.""" - if seed is not None: - logging.debug("Recreating EnvPool environment to set seed.") - self.create_env_kwargs["seed"] = seed - self._env = self._build_env( - env_name=self.env_name, - num_workers=self.num_workers, - create_env_kwargs=self.create_env_kwargs, - ) - - def _check_kwargs(self, kwargs: Dict): - for arg in ["num_workers", "env_name", "create_env_kwargs"]: - if arg not in kwargs: - raise TypeError(f"Expected '{arg}' to be part of kwargs") - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(env={self.env_name}, num_workers={self.num_workers}, device={self.device})" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 15168f20411..55e057ffa47 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -7,13 +7,14 @@ import abc from copy import deepcopy -from typing import Any, Callable, Dict, Iterator, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import numpy as np import torch import torch.nn as nn +from tensordict import unravel_key from tensordict.tensordict import TensorDictBase - +from tensordict.utils import NestedKey from torchrl._utils import prod, seed_generator from torchrl.data.tensor_specs import ( @@ -23,7 +24,12 @@ UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs.utils import get_available_libraries, step_mdp +from torchrl.envs.utils import ( + _replace_last, + get_available_libraries, + step_mdp, + terminated_or_truncated, +) LIBRARIES = get_available_libraries() @@ -77,7 +83,12 @@ def specs(self, value: CompositeSpec): @staticmethod def metadata_from_env(env) -> EnvMetaData: tensordict = env.fake_tensordict().clone() - tensordict.set("_reset", torch.zeros_like(tensordict.get(env.done_key))) + + for done_key in env.done_keys: + tensordict.set( + _replace_last(done_key, "_reset"), + torch.zeros_like(tensordict.get(("next", done_key))), + ) specs = env.specs.to("cpu") @@ -89,7 +100,7 @@ def metadata_from_env(env) -> EnvMetaData: return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked) def expand(self, *size: int) -> EnvMetaData: - tensordict = self.tensordict.expand(*size).to_tensordict() + tensordict = self.tensordict.expand(*size).clone() batch_size = torch.Size(list(size)) return EnvMetaData( tensordict, @@ -118,64 +129,22 @@ def to(self, device: DEVICE_TYPING) -> EnvMetaData: ) -class EnvBase(nn.Module, metaclass=abc.ABCMeta): - """Abstract environment parent class. 
+class _EnvPostInit(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance: EnvBase = super().__call__(*args, **kwargs) + # we create the done spec by adding a done/terminated entry if one is missing + instance._create_done_specs() + # we access lazy attributed to make sure they're built properly. + # This isn't done in `__init__` because we don't know if supre().__init__ + # will be called before or after the specs, batch size etc are set. + _ = instance.done_spec + _ = instance.reward_spec + _ = instance.state_spec + return instance - Properties: - observation_spec (CompositeSpec): sampling spec of the observations. Must be a - :class:`torchrl.data.CompositeSpec` instance. The keys listed in the - spec are directly accessible after reset. - In TorchRL, even though they are not properly speaking "observations" - all info, states, results of transforms etc. are stored in the - observation_spec. Therefore, "observation_spec" should be thought as - a generic data container for environment outputs that are not done - or reward data. - reward_spec (TensorSpec): the (leaf) spec of the reward. If the reward - is nested within a tensordict, its location can be accessed via - the ``reward_key`` attribute: - - >>> # accessing reward spec: - >>> reward_spec = env.reward_spec - >>> reward_spec = env.output_spec['_reward_spec'][env.reward_key] - >>> # accessing reward: - >>> reward = env.fake_tensordict()[('next', *env.reward_key)] - - done_spec (TensorSpec): the (leaf) spec of the done. If the done - is nested within a tensordict, its location can be accessed via - the ``done_key`` attribute. - - >>> # accessing done spec: - >>> done_spec = env.done_spec - >>> done_spec = env.output_spec['_done_spec'][env.done_key] - >>> # accessing done: - >>> done = env.fake_tensordict()[('next', *env.done_key)] - - action_spec (TensorSpec): the ampling spec of the actions. This attribute - is contained in input_spec. - - >>> # accessing action spec: - >>> action_spec = env.action_spec - >>> action_spec = env.input_spec['_action_spec'][env.action_key] - >>> # accessing action: - >>> action = env.fake_tensordict()[env.action_key] - - output_spec (CompositeSpec): The container for all output specs (reward, - done and observation). - input_spec (CompositeSpec): the container for all input specs (actions - and possibly others). - batch_size (torch.Size): number of environments contained in the instance; - device (torch.device): device where the env input and output are expected to live - run_type_checks (bool): if ``True``, the observation and reward dtypes - will be compared against their respective spec and an exception - will be raised if they don't match. - Defaults to False. - - .. note:: - The usage of ``done_key``, ``reward_key`` and ``action_key`` is aimed at - facilitating the custom placement of done, reward and action data within - the tensordict structures produced and read by the environment. - In most cases, these attributes can be ignored and the default values - (``"done"``, ``"reward"`` and ``"action"``) can be used. + +class EnvBase(nn.Module, metaclass=_EnvPostInit): + """Abstract environment parent class. 
Methods: step (TensorDictBase -> TensorDictBase): step in the environment @@ -192,20 +161,22 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): torch.Size([]) >>> env.input_spec CompositeSpec( - action: BoundedTensorSpec( - shape=torch.Size([1]), - space=ContinuousBox( - minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), - maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), - device=cpu, - dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([])) + full_state_spec: None, + full_action_spec: CompositeSpec( + action: BoundedTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> env.action_spec BoundedTensorSpec( shape=torch.Size([1]), space=ContinuousBox( - minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), - maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous) @@ -214,8 +185,8 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): observation: BoundedTensorSpec( shape=torch.Size([3]), space=ContinuousBox( - minimum=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), - maximum=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])) @@ -236,23 +207,23 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): >>> # the output_spec contains all the expected outputs >>> env.output_spec CompositeSpec( - observation: CompositeSpec( - observation: BoundedTensorSpec( - shape=torch.Size([3]), - space=ContinuousBox( - minimum=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), - maximum=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), - device=cpu, - dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([])), - reward: CompositeSpec( + full_reward_spec: CompositeSpec( reward: UnboundedContinuousTensorSpec( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - done: CompositeSpec( + full_observation_spec: CompositeSpec( + observation: BoundedTensorSpec( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])), + full_done_spec: CompositeSpec( done: DiscreteTensorSpec( shape=torch.Size([1]), space=DiscreteBox(n=2), @@ -264,14 +235,18 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): def __init__( self, - device: DEVICE_TYPING = "cpu", + device: DEVICE_TYPING = None, dtype: 
Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, + allow_done_after_reset: bool = False, ): - self.__dict__["_done_key"] = None - self.__dict__["_reward_key"] = None - self.__dict__["_action_key"] = None + if device is None: + device = torch.device("cpu") + self.__dict__.setdefault("_done_keys", None) + self.__dict__.setdefault("_reward_keys", None) + self.__dict__.setdefault("_action_keys", None) + self.__dict__.setdefault("_batch_size", None) if device is not None: self.__dict__["_device"] = torch.device(device) output_spec = self.__dict__.get("_output_spec", None) @@ -290,6 +265,7 @@ def __init__( # it's already been set self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks + self._allow_done_after_reset = allow_done_after_reset @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): @@ -323,7 +299,15 @@ def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): return super().__new__(cls) def __setattr__(self, key, value): - if key in ("_input_spec", "_observation_spec", "_action_spec", "_reward_spec"): + if key in ( + "_input_spec", + "_observation_spec", + "_action_spec", + "_reward_spec", + "_output_spec", + "_state_spec", + "_done_spec", + ): raise AttributeError( "To set an environment spec, please use `env.observation_spec = obs_spec` (without the leading" " underscore)." ) @@ -332,7 +316,7 @@ def __setattr__(self, key, value): @property def batch_locked(self) -> bool: - """Whether the environnement can be used with a batch size different from the one it was initialized with or not. + """Whether the environment can be used with a batch size different from the one it was initialized with or not. If True, the env needs to be used with a tensordict having the same batch size as the env. batch_locked is an immutable property. @@ -353,7 +337,7 @@ def run_type_checks(self, run_type_checks: bool) -> None: @property def batch_size(self) -> torch.Size: - _batch_size = getattr(self, "_batch_size", None) + _batch_size = self.__dict__["_batch_size"] if _batch_size is None: _batch_size = self._batch_size = torch.Size([]) return _batch_size @@ -398,10 +382,40 @@ def ndim(self): # Parent specs: input and output spec. @property def input_spec(self) -> TensorSpec: + """Input spec. + + The composite spec containing all specs for data input to the environments. + + It contains: + + - "full_action_spec": the spec of the input actions + - "full_state_spec": the spec of all other environment inputs + + This attribute is locked and should be read-only. + Instead, to set the specs contained in it, use the respective properties. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.input_spec + CompositeSpec( + full_state_spec: None, + full_action_spec: CompositeSpec( + action: BoundedTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + + """ input_spec = self.__dict__.get("_input_spec", None) if input_spec is None: input_spec = CompositeSpec( - _state_spec=None, + full_state_spec=None, shape=self.batch_size, device=self.device, ).lock_() @@ -414,6 +428,50 @@ def input_spec(self, value: TensorSpec) -> None: @property def output_spec(self) -> TensorSpec: + """Output spec. + + The composite spec containing all specs for data output from the environments. + + It contains: + + - "full_reward_spec": the spec of reward + - "full_done_spec": the spec of done + - "full_observation_spec": the spec of all other environment outputs + + This attibute is locked and should be read-only. + Instead, to set the specs contained in it, use the respective properties. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.output_spec + CompositeSpec( + full_reward_spec: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])), + full_observation_spec: CompositeSpec( + observation: BoundedTensorSpec( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])), + full_done_spec: CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + + """ output_spec = self.__dict__.get("_output_spec", None) if output_spec is None: output_spec = CompositeSpec( @@ -428,60 +486,133 @@ def output_spec(self, value: TensorSpec) -> None: raise RuntimeError("output_spec is protected.") # Action spec - def _get_action_key(self): - keys = self.input_spec["_action_spec"].keys(True, True) - for key in keys: - # the first key is the action - if not isinstance(key, tuple): - key = (key,) - break - else: + def _get_action_keys(self): + keys = self.input_spec["full_action_spec"].keys(True, True) + if not len(keys): raise AttributeError("Could not find action spec") - self.__dict__["_action_key"] = key - return key + keys = list(keys) + self.__dict__["_action_keys"] = keys + return keys @property - def action_key(self): - """The action key of an environment. - - By default, non-nested keys are stored in the 'action' key. + def action_keys(self) -> List[NestedKey]: + """The action keys of an environment. - If the action is in a nested tensordict, this property will return its - location. + By default, there will only be one key named "action". 
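A quick check of the default layout (assuming the Gym backend and the single-action Pendulum-v1 environment are available):

            >>> from torchrl.envs.libs.gym import GymEnv
            >>> GymEnv("Pendulum-v1").action_keys
            ['action']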
""" - out = self._action_key + out = self._action_keys if out is None: - out = self._get_action_key() + out = self._get_action_keys() return out + @property + def action_key(self) -> NestedKey: + """The action key of an environment. + + By default, this will be "action". + + If there is more than one action key in the environment, this function will raise an exception. + """ + if len(self.action_keys) > 1: + raise KeyError( + "action_key requested but more than one key present in the environment" + ) + return self.action_keys[0] + # Action spec: action specs belong to input_spec @property def action_spec(self) -> TensorSpec: - """The ``action`` leaf spec. + """The ``action`` spec. - This property will always return the leaf spec of the action attribute, - which can be accessed in a typical rollout via + The ``action_spec`` is always stored as a composite spec. - >>> fake_td = env.fake_tensordict() # a typical tensordict - >>> action = fake_td[env.action_key] + If the action spec is provided as a simple spec, this will be returned. + + >>> env.action_spec = UnboundedContinuousTensorSpec(1) + >>> env.action_spec + UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + + If the action spec is provided as a composite spec and contains only one leaf, + this function will return just the leaf. + + >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1)}}) + >>> env.action_spec + UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + + If the action spec is provided as a composite spec and has more than one leaf, + this function will return the whole spec. + + >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1), "another_action": DiscreteTensorSpec(1)}}) + >>> env.action_spec + CompositeSpec( + nested: CompositeSpec( + action: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + another_action: DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=1), + device=cpu, + dtype=torch.int64, + domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + To retrieve the full spec passed, use: + + >>> env.input_spec["full_action_spec"] This property is mutable. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.action_spec + BoundedTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) """ try: - action_spec = self.input_spec["_action_spec"] + action_spec = self.input_spec["full_action_spec"] except (KeyError, AttributeError): raise KeyError("Failed to find the action_spec.") - try: - out = action_spec[self.action_key] - except KeyError: - # the key may have changed - raise KeyError( - "The action_key attribute seems to have changed. " - "This occurs when a action_spec is updated without " - "calling `env.action_spec = new_spec`. " - "Make sure you rely on this type of command " - "to set the action and other specs." - ) + + if len(self.action_keys) > 1: + out = action_spec + else: + try: + out = action_spec[self.action_key] + except KeyError: + # the key may have changed + raise KeyError( + "The action_key attribute seems to have changed. " + "This occurs when a action_spec is updated without " + "calling `env.action_spec = new_spec`. " + "Make sure you rely on this type of command " + "to set the action and other specs." + ) return out @@ -491,9 +622,17 @@ def action_spec(self, value: TensorSpec) -> None: self.input_spec.unlock_() device = self.input_spec.device try: - delattr(self, "_action_key") + delattr(self, "_action_keys") except AttributeError: pass + if not hasattr(value, "shape"): + raise TypeError( + f"action_spec of type {type(value)} do not have a shape attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) if isinstance(value, CompositeSpec): for _ in value.values(True, True): # noqa: B007 @@ -508,77 +647,164 @@ def action_spec(self, value: TensorSpec) -> None: action=value.to(device), shape=self.batch_size, device=device ) - self.input_spec["_action_spec"] = value.to(device) - self._get_action_key() + self.input_spec["full_action_spec"] = value.to(device) + self._get_action_keys() finally: self.input_spec.lock_() + @property + def full_action_spec(self) -> CompositeSpec: + """The full action spec. + + ``full_action_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + that contains all the action entries. + + Examples: + >>> from torchrl.envs import BraxEnv + >>> for envname in BraxEnv.available_envs: + ... 
break + >>> env = BraxEnv(envname) + >>> env.full_action_spec + CompositeSpec( + action: BoundedTensorSpec( + shape=torch.Size([8]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])) + + """ + return self.input_spec["full_action_spec"] + + @full_action_spec.setter + def full_action_spec(self, spec: CompositeSpec) -> None: + self.action_spec = spec + # Reward spec - def _get_reward_key(self): - keys = self.output_spec["_reward_spec"].keys(True, True) - for key in keys: - # the first key is the reward - if not isinstance(key, tuple): - key = (key,) - break - else: + def _get_reward_keys(self): + keys = self.output_spec["full_reward_spec"].keys(True, True) + if not len(keys): raise AttributeError("Could not find reward spec") - self.__dict__["_reward_key"] = key - return key + keys = list(keys) + self.__dict__["_reward_keys"] = keys + return keys + + @property + def reward_keys(self) -> List[NestedKey]: + """The reward keys of an environment. + + By default, there will only be one key named "reward". + """ + result = list(self.full_reward_spec.keys(True, True)) + return result @property def reward_key(self): """The reward key of an environment. - By default, non-nested keys are stored in the ``'reward'`` entry. + By default, this will be "reward". - If the reward is in a nested tensordict, this property will return its - location. + If there is more than one reward key in the environment, this function will raise an exception. """ - out = self._reward_key - if out is None: - out = self._get_reward_key() - return out + if len(self.reward_keys) > 1: + raise KeyError( + "reward_key requested but more than one key present in the environment" + ) + return self.reward_keys[0] - # Done spec: reward specs belong to output_spec + # Reward spec: reward specs belong to output_spec @property def reward_spec(self) -> TensorSpec: - """The ``reward`` leaf spec. + """The ``reward`` spec. + + The ``reward_spec`` is always stored as a composite spec. + + If the reward spec is provided as a simple spec, this will be returned. - This property will always return the leaf spec of the reward attribute, - which can be accessed in a typical rollout via + >>> env.reward_spec = UnboundedContinuousTensorSpec(1) + >>> env.reward_spec + UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + + If the reward spec is provided as a composite spec and contains only one leaf, + this function will return just the leaf. + + >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1)}}) + >>> env.reward_spec + UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + + If the reward spec is provided as a composite spec and has more than one leaf, + this function will return the whole spec. 
+ + >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1), "another_reward": DiscreteTensorSpec(1)}}) + >>> env.reward_spec + CompositeSpec( + nested: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + another_reward: DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=1), + device=cpu, + dtype=torch.int64, + domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + To retrieve the full spec passed, use: - >>> fake_td = env.fake_tensordict() # a typical tensordict - >>> reward = fake_td[("next", *env.reward_key)] + >>> env.output_spec["full_reward_spec"] This property is mutable. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.reward_spec + UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous) """ try: - reward_spec = self.output_spec["_reward_spec"] + reward_spec = self.output_spec["full_reward_spec"] except (KeyError, AttributeError): # populate the "reward" entry - # this will be raised if there is not _reward_spec (unlikely) or no reward_key + # this will be raised if there is not full_reward_spec (unlikely) or no reward_key # Since output_spec is lazily populated with an empty composite spec for # reward_spec, the second case is much more likely to occur. - self.reward_spec = out = UnboundedContinuousTensorSpec( + self.reward_spec = UnboundedContinuousTensorSpec( shape=(*self.batch_size, 1), device=self.device, ) - reward_spec = self.output_spec["_reward_spec"] - finally: - try: - out = reward_spec[self.reward_key] - except KeyError: - # the key may have changed - raise KeyError( - "The reward_key attribute seems to have changed. " - "This occurs when a reward_spec is updated without " - "calling `env.reward_spec = new_spec`. " - "Make sure you rely on this type of command " - "to set the reward and other specs." - ) + reward_spec = self.output_spec["full_reward_spec"] - return out + reward_keys = self.reward_keys + if len(reward_keys) > 1 or not len(reward_keys): + return reward_spec + else: + return reward_spec[self.reward_keys[0]] @reward_spec.setter def reward_spec(self, value: TensorSpec) -> None: @@ -586,7 +812,7 @@ def reward_spec(self, value: TensorSpec) -> None: self.output_spec.unlock_() device = self.output_spec.device try: - delattr(self, "_reward_key") + delattr(self, "_reward_keys") except AttributeError: pass if not hasattr(value, "shape"): @@ -596,10 +822,10 @@ def reward_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): - for nestedval in value.values(True, True): # noqa: B007 + for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( @@ -607,89 +833,268 @@ def reward_spec(self, value: TensorSpec) -> None: "This is currently not permitted." 
) else: - nestedval = value value = CompositeSpec( reward=value.to(device), shape=self.batch_size, device=device ) - if len(nestedval.shape) == 0: - raise RuntimeError( - "the reward_spec shape cannot be empty (this error" - " usually comes from trying to set a reward_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - self.output_spec["_reward_spec"] = value.to(device) - self._get_reward_key() + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the reward_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." + ) + self.output_spec["full_reward_spec"] = value.to(device) + self._get_reward_keys() finally: self.output_spec.lock_() + @property + def full_reward_spec(self) -> CompositeSpec: + """The full reward spec. + + ``full_reward_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + that contains all the reward entries. + + Examples: + >>> import gymnasium + >>> from torchrl.envs import GymWrapper, TransformedEnv, RenameTransform + >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1")) + >>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward"))) + >>> env.full_reward_spec + CompositeSpec( + nested: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + """ + return self.output_spec["full_reward_spec"] + + @full_reward_spec.setter + def full_reward_spec(self, spec: CompositeSpec) -> None: + self.reward_spec = spec + # done spec - def _get_done_key(self): - keys = self.output_spec["_done_spec"].keys(True, True) - for key in keys: - # the first key is the reward - if not isinstance(key, tuple): - key = (key,) - break - else: - raise AttributeError( - f"Could not find done spec: {self.output_spec['_done_spec']}" + def _get_done_keys(self): + if "full_done_spec" not in self.output_spec.keys(): + # populate the "done" entry + # this will be raised if there is not full_done_spec (unlikely) or no done_key + # Since output_spec is lazily populated with an empty composite spec for + # done_spec, the second case is much more likely to occur. + self.done_spec = DiscreteTensorSpec( + n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device ) - self.__dict__["_done_key"] = key - return key + + keys = self.output_spec["full_done_spec"].keys(True, True) + if not len(keys): + raise AttributeError("Could not find done spec") + keys = list(keys) + self.__dict__["_done_keys"] = keys + return keys + + @property + def done_keys(self) -> List[NestedKey]: + """The done keys of an environment. + + By default, there will only be one key named "done". + """ + result = list(self.full_done_spec.keys(True, True)) + return result @property def done_key(self): """The done key of an environment. - By default, non-nested keys are stored in the ``'done'`` entry. + By default, this will be "done". - If the done is in a nested tensordict, this property will return its - location. 
+ If there is more than one done key in the environment, this function will raise an exception. """ - out = self._done_key - if out is None: - out = self._get_done_key() - return out + if len(self.done_keys) > 1: + raise KeyError( + "done_key requested but more than one key present in the environment" + ) + return self.done_keys[0] + + @property + def full_done_spec(self) -> CompositeSpec: + """The full done spec. + + ``full_done_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + that contains all the done entries. + It can be used to generate fake data with a structure that mimics the + one obtained at runtime. + + Examples: + >>> import gymnasium + >>> from torchrl.envs import GymWrapper + >>> env = GymWrapper(gymnasium.make("Pendulum-v1")) + >>> env.full_done_spec + CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete), + truncated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete), device=cpu, shape=torch.Size([])) + + """ + return self.output_spec["full_done_spec"] + + @full_done_spec.setter + def full_done_spec(self, spec: CompositeSpec) -> None: + self.done_spec = spec # Done spec: done specs belong to output_spec @property def done_spec(self) -> TensorSpec: - """The ``done`` leaf spec. + """The ``done`` spec. + + The ``done_spec`` is always stored as a composite spec. - This property will always return the leaf spec of the done attribute, - which can be accessed in a typical rollout via + If the done spec is provided as a simple spec, this will be returned. - >>> fake_td = env.fake_tensordict() # a typical tensordict - >>> done = fake_td[("next", *env.done_key)] + >>> env.done_spec = DiscreteTensorSpec(2, dtype=torch.bool) + >>> env.done_spec + DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete) + + If the done spec is provided as a composite spec and contains only one leaf, + this function will return just the leaf. + + >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec + DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete) + + If the done spec is provided as a composite spec and has more than one leaf, + this function will return the whole spec. + + >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool), "another_done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec + CompositeSpec( + nested: CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete), + another_done: DiscreteTensorSpec( + shape=torch.Size([]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + To always retrieve the full spec passed, use: + + >>> env.output_spec["full_done_spec"] This property is mutable. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.done_spec + DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + device=cpu, + dtype=torch.bool, + domain=discrete) + """ + done_spec = self.output_spec["full_done_spec"] + return done_spec + + def _create_done_specs(self): + """Reads through the done specs and makes it so that it's complete. + + If the done_specs contain only a ``"done"`` entry, a similar ``"terminated"`` entry is created. + Same goes if only ``"terminated"`` key is present. + + If none of ``"done"`` and ``"terminated"`` can be found and the spec is not + empty, nothing is changed. + """ try: - done_spec = self.output_spec["_done_spec"] - except (KeyError, AttributeError): - # populate the "done" entry - # this will be raised if there is not _done_spec (unlikely) or no done_key - # Since output_spec is lazily populated with an empty composite spec for - # done_spec, the second case is much more likely to occur. - self.done_spec = DiscreteTensorSpec( - n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device + full_done_spec = self.output_spec["full_done_spec"] + except KeyError: + full_done_spec = CompositeSpec( + shape=self.output_spec.shape, device=self.output_spec.device ) - done_spec = self.output_spec["_done_spec"] - finally: - try: - out = done_spec[self.done_key] - except KeyError: - # the key may have changed - raise KeyError( - "The done_key attribute seems to have changed. " - "This occurs when a done_spec is updated without " - "calling `env.done_spec = new_spec`. " - "Make sure you rely on this type of command " - "to set the done and other specs." + full_done_spec["done"] = DiscreteTensorSpec( + n=2, + shape=(*full_done_spec.shape, 1), + dtype=torch.bool, + device=self.device, + ) + full_done_spec["terminated"] = DiscreteTensorSpec( + n=2, + shape=(*full_done_spec.shape, 1), + dtype=torch.bool, + device=self.device, + ) + self.output_spec.unlock_() + self.output_spec["full_done_spec"] = full_done_spec + self.output_spec.lock_() + return + + def check_local_done(spec): + shape = None + for key, item in list( + spec.items() + ): # list to avoid error due to in-loop changes + # in the case where the spec is non-empty and there is no done and no terminated, we do nothing + if key == "done" and "terminated" not in spec.keys(): + spec["terminated"] = item.clone() + elif key == "terminated" and "done" not in spec.keys(): + spec["done"] = item.clone() + elif isinstance(item, CompositeSpec): + check_local_done(item) + else: + if shape is None: + shape = item.shape + continue + # checks that all shape match + if shape != item.shape: + raise ValueError( + f"All shapes should match in done_spec {spec} (shape={shape}, key={key})." 
+ ) + + # if the spec is empty, we need to add a done and terminated manually + if spec.is_empty(): + spec["done"] = DiscreteTensorSpec( + n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device + ) + spec["terminated"] = DiscreteTensorSpec( + n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) - return out + self.output_spec.unlock_() + check_local_done(full_done_spec) + self.output_spec["full_done_spec"] = full_done_spec + self.output_spec.lock_() + return @done_spec.setter def done_spec(self, value: TensorSpec) -> None: @@ -697,7 +1102,7 @@ def done_spec(self, value: TensorSpec) -> None: self.output_spec.unlock_() device = self.output_spec.device try: - delattr(self, "_done_key") + delattr(self, "_done_keys") except AttributeError: pass if not hasattr(value, "shape"): @@ -707,10 +1112,10 @@ def done_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): - for nestedval in value.values(True, True): # noqa: B007 + for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( @@ -718,32 +1123,61 @@ def done_spec(self, value: TensorSpec) -> None: "This is currently not permitted." ) else: - nestedval = value value = CompositeSpec( - done=value.to(device), shape=self.batch_size, device=device + done=value.to(device), + terminated=value.to(device), + shape=self.batch_size, + device=device, ) - if len(nestedval.shape) == 0: - raise RuntimeError( - "the done_spec shape cannot be empty (this error" - " usually comes from trying to set a done_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - if len(list(value.keys())) == 0: - raise RuntimeError - self.output_spec["_done_spec"] = value.to(device) - self._get_done_key() + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the done_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." + ) + self.output_spec["full_done_spec"] = value.to(device) + self._create_done_specs() + self._get_done_keys() finally: self.output_spec.lock_() # observation spec: observation specs belong to output_spec @property def observation_spec(self) -> CompositeSpec: - observation_spec = self.output_spec["_observation_spec"] + """Observation spec. + + Must be a :class:`torchrl.data.CompositeSpec` instance. + The keys listed in the spec are directly accessible after reset and step. + + In TorchRL, even though they are not properly speaking "observations" + all info, states, results of transforms etc. outputs from the environment are stored in the + ``observation_spec``. + + Therefore, ``"observation_spec"`` should be thought as + a generic data container for environment outputs that are not done or reward data. 
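As a side note on the ``done_spec`` setter and ``_create_done_specs`` above: assigning a plain (non-composite) done spec automatically creates a matching ``"terminated"`` entry. A minimal sketch, assuming a Gym backend and a recent TorchRL version:

>>> import torch
>>> from torchrl.data import DiscreteTensorSpec
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.done_spec = DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool)
>>> sorted(env.full_done_spec.keys(True, True))   # a "terminated" twin is added automatically
['done', 'terminated']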
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = GymEnv("Pendulum-v1") + >>> env.observation_spec + CompositeSpec( + observation: BoundedTensorSpec( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([])) + + """ + observation_spec = self.output_spec["full_observation_spec"] if observation_spec is None: observation_spec = CompositeSpec(shape=self.batch_size, device=self.device) self.output_spec.unlock_() - self.output_spec["_observation_spec"] = observation_spec + self.output_spec["full_observation_spec"] = observation_spec self.output_spec.lock_() return observation_spec @@ -762,18 +1196,57 @@ def observation_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - self.output_spec["_observation_spec"] = value.to(device) + self.output_spec["full_observation_spec"] = value.to(device) finally: self.output_spec.lock_() + @property + def full_observation_spec(self) -> CompositeSpec: + return self.observation_spec + + @full_observation_spec.setter + def full_observation_spec(self, spec: CompositeSpec): + self.observation_spec = spec + # state spec: state specs belong to input_spec @property def state_spec(self) -> CompositeSpec: - state_spec = self.input_spec["_state_spec"] + """State spec. + + Must be a :class:`torchrl.data.CompositeSpec` instance. + The keys listed here should be provided as input alongside actions to the environment. + + In TorchRL, even though they are not properly speaking "state" + all inputs to the environment that are not actions are stored in the + ``state_spec``. + + Therefore, ``"state_spec"`` should be thought as + a generic data container for environment inputs that are not action data. + + Examples: + >>> from torchrl.envs import BraxEnv + >>> for envname in BraxEnv.available_envs: + ... break + >>> env = BraxEnv(envname) + >>> env.state_spec + CompositeSpec( + state: CompositeSpec( + pipeline_state: CompositeSpec( + q: UnboundedContinuousTensorSpec( + shape=torch.Size([15]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), + [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + + """ + state_spec = self.input_spec["full_state_spec"] if state_spec is None: state_spec = CompositeSpec(shape=self.batch_size, device=self.device) self.input_spec.unlock_() - self.input_spec["_state_spec"] = state_spec + self.input_spec["full_state_spec"] = state_spec self.input_spec.lock_() return state_spec @@ -782,7 +1255,7 @@ def state_spec(self, value: CompositeSpec) -> None: try: self.input_spec.unlock_() if value is None: - self.input_spec["_state_spec"] = CompositeSpec( + self.input_spec["full_state_spec"] = CompositeSpec( device=self.device, shape=self.batch_size ) else: @@ -797,10 +1270,41 @@ def state_spec(self, value: CompositeSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - self.input_spec["_state_spec"] = value.to(device) + self.input_spec["full_state_spec"] = value.to(device) finally: self.input_spec.lock_() + @property + def full_state_spec(self) -> CompositeSpec: + """The full state spec. 
+ + ``full_state_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + that contains all the state entries (ie, the input data that is not action). + + Examples: + >>> from torchrl.envs import BraxEnv + >>> for envname in BraxEnv.available_envs: + ... break + >>> env = BraxEnv(envname) + >>> env.full_state_spec + CompositeSpec( + state: CompositeSpec( + pipeline_state: CompositeSpec( + q: UnboundedContinuousTensorSpec( + shape=torch.Size([15]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), + [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + + """ + return self.state_spec + + @full_state_spec.setter + def full_state_spec(self, spec: CompositeSpec) -> None: + self.state_spec = spec + def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. @@ -810,6 +1314,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: Args: tensordict (TensorDictBase): Tensordict containing the action to be taken. + If the input tensordict contains a ``"next"`` entry, the values contained in it + will prevail over the newly computed values. This gives a mechanism + to override the underlying computations. Returns: the input tensordict, modified in place with the resulting observations, done state and reward @@ -818,26 +1325,78 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: """ # sanity check self._assert_tensordict_shape(tensordict) + next_preset = tensordict.get("next", None) + + next_tensordict = self._step(tensordict) + next_tensordict = self._step_proc_data(next_tensordict) + if next_preset is not None: + # tensordict could already have a "next" key + next_tensordict.update(next_preset) + tensordict.set("next", next_tensordict) + return tensordict - tensordict_out = self._step(tensordict) - # this tensordict should contain a "next" key - try: - next_tensordict_out = tensordict_out.get("next") - except KeyError: - raise RuntimeError( - "The value returned by env._step must be a tensordict where the " - "values at t+1 have been written under a 'next' entry. This " - f"tensordict couldn't be found in the output, got: {tensordict_out}." - ) - if tensordict_out is tensordict: - raise RuntimeError( - "EnvBase._step should return outplace changes to the input " - "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " - "tensordict.select()) inside _step before writing new tensors onto this new instance." - ) - + @classmethod + def _complete_done( + cls, done_spec: CompositeSpec, data: TensorDictBase + ) -> TensorDictBase: + """Completes the data structure at step time to put missing done keys.""" + # by default, if a done key is missing, it is assumed that it is False + # except in 2 cases: (1) there is a "done" but no "terminated" or (2) + # there is a "terminated" but no "done". 
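Returning to the ``"next"`` override mentioned in the ``step()`` docstring above, here is a hedged sketch of how a pre-filled entry takes precedence over the freshly computed value (entry names follow the default Pendulum specs; the exact printed value is illustrative):

>>> import torch
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> td = env.rand_action(env.reset())
>>> td.set(("next", "reward"), torch.ones(1))   # value pre-filled by the user
>>> td = env.step(td)
>>> td["next", "reward"]                        # the preset value prevails over the computed reward
tensor([1.])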
+ if done_spec.ndim: + leading_dim = data.shape[: -done_spec.ndim] + else: + leading_dim = data.shape + vals = {} + i = -1 + for i, (key, item) in enumerate(done_spec.items()): # noqa: B007 + val = data.get(key, None) + if isinstance(item, CompositeSpec): + cls._complete_done(item, val) + continue + shape = (*leading_dim, *item.shape) + if val is not None: + if val.shape != shape: + data.set(key, val.reshape(shape)) + vals[key] = val + + if len(vals) < i + 1: + # complete missing dones: we only want to do that if we don't have enough done values + data_keys = set(data.keys()) + done_spec_keys = set(done_spec.keys()) + for key, item in done_spec.items(False, True): + val = vals.get(key, None) + if ( + key == "done" + and val is not None + and "terminated" in done_spec_keys + and "terminated" not in data_keys + ): + if "truncated" in data_keys: + raise RuntimeError( + "Cannot infer the value of terminated when only done and truncated are present." + ) + data.set("terminated", val) + elif ( + key == "terminated" + and val is not None + and "done" in done_spec_keys + and "done" not in data_keys + ): + if "truncated" in data_keys: + done = val | data.get("truncated") + data.set("done", done) + else: + data.set("done", val) + elif val is None: + # we must keep this here: we only want to fill with 0s if we're sure + # done should not be copied to terminated or terminated to done + # in this case, just fill with 0s + data.set(key, item.zero(leading_dim)) + return data + + def _step_proc_data(self, next_tensordict_out): # TODO: Refactor this using reward spec - reward = next_tensordict_out.get(self.reward_key) # unsqueeze rewards if needed # the input tensordict may have more leading dimensions than the batch_size # e.g. in model-based contexts. @@ -848,46 +1407,47 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: if dims else next_tensordict_out.shape ) - expected_reward_shape = torch.Size( - [*leading_batch_size, *self.reward_spec.shape] - ) - actual_reward_shape = reward.shape - if actual_reward_shape != expected_reward_shape: - reward = reward.view(expected_reward_shape) - next_tensordict_out.set(self.reward_key, reward) - - # TODO: Refactor this using done spec - done = next_tensordict_out.get(self.done_key) - # unsqueeze done if needed - expected_done_shape = torch.Size([*leading_batch_size, *self.done_spec.shape]) - actual_done_shape = done.shape - if actual_done_shape != expected_done_shape: - done = done.view(expected_done_shape) - next_tensordict_out.set(self.done_key, done) - tensordict_out.set("next", next_tensordict_out) - - if self.run_type_checks: - for key in self._select_observation_keys(tensordict_out): - obs = tensordict_out.get(key) - self.observation_spec.type_check(obs, key) + for reward_key in self.reward_keys: + reward = next_tensordict_out.get(reward_key) + expected_reward_shape = torch.Size( + [ + *leading_batch_size, + *self.output_spec["full_reward_spec"][reward_key].shape, + ] + ) + actual_reward_shape = reward.shape + if actual_reward_shape != expected_reward_shape: + reward = reward.view(expected_reward_shape) + next_tensordict_out.set(reward_key, reward) - if ( - next_tensordict_out.get(self.reward_key).dtype - is not self.reward_spec.dtype - ): - raise TypeError( - f"expected reward.dtype to be {self.reward_spec.dtype} " - f"but got {tensordict_out.get(self.reward_key).dtype}" - ) + self._complete_done(self.full_done_spec, next_tensordict_out) - if next_tensordict_out.get(self.done_key).dtype is not self.done_spec.dtype: - raise TypeError( - 
f"expected done.dtype to be torch.bool but got {tensordict_out.get(self.done_key).dtype}" - ) - # tensordict could already have a "next" key - tensordict.update(tensordict_out) + if self.run_type_checks: + for key, spec in self.observation_spec.items(): + obs = next_tensordict_out.get(key) + spec.type_check(obs) + + for reward_key in self.reward_keys: + if ( + next_tensordict_out.get(reward_key).dtype + is not self.output_spec[ + unravel_key(("full_reward_spec", reward_key)) + ].dtype + ): + raise TypeError( + f"expected reward.dtype to be {self.output_spec[unravel_key(('full_reward_spec',reward_key))]} " + f"but got {next_tensordict_out.get(reward_key).dtype}" + ) - return tensordict + for done_key in self.done_keys: + if ( + next_tensordict_out.get(done_key).dtype + is not self.output_spec["full_done_spec", done_key].dtype + ): + raise TypeError( + f"expected done.dtype to be {self.output_spec['full_done_spec', done_key].dtype} but got {next_tensordict_out.get(done_key).dtype}" + ) + return next_tensordict_out def _get_in_keys_to_exclude(self, tensordict): if self._cache_in_keys is None: @@ -931,19 +1491,13 @@ def reset( a tensordict (or the input tensordict, if any), modified in place with the resulting observations. """ - if tensordict is not None and "_reset" in tensordict.keys(): + if tensordict is not None: self._assert_tensordict_shape(tensordict) - _reset = tensordict.get("_reset") - if _reset.shape[-len(self.done_spec.shape) :] != self.done_spec.shape: - raise RuntimeError( - "_reset flag in tensordict should follow env.done_spec" - ) - else: - _reset = None tensordict_reset = self._reset(tensordict, **kwargs) - if tensordict_reset.device != self.device: - tensordict_reset = tensordict_reset.to(self.device) + # We assume that this is done properly + # if tensordict_reset.device != self.device: + # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -955,28 +1509,38 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - if len(self.batch_size): - leading_dim = tensordict_reset.shape[: -len(self.batch_size)] - else: - leading_dim = tensordict_reset.shape - if self.done_spec is not None and self.done_key not in tensordict_reset.keys( - True, True - ): - tensordict_reset.set( - self.done_key, - self.done_spec.zero(leading_dim), - ) + self._complete_done(self.full_done_spec, tensordict_reset) + + if not self._allow_done_after_reset: + # we iterate over (reset_key, (done_key, truncated_key)) and check that all + # values where reset was true now have a done set to False. + # If no reset was present, all done and truncated must be False + for reset_key, done_key_group in zip( + self.reset_keys, self.done_keys_groups + ): + reset_value = ( + tensordict.get(reset_key, default=None) + if tensordict is not None + else None + ) + if reset_value is not None: + for done_key in done_key_group: + if tensordict_reset.get(done_key)[reset_value].any(): + raise RuntimeError( + f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed." + ) + else: + for done_key in done_key_group: + if tensordict_reset.get(done_key).any(): + raise RuntimeError( + f"Env done entry '{done_key}' was (partially) True after a call to reset(). This is not allowed." 
+ ) - if (_reset is None and tensordict_reset.get(self.done_key).any()) or ( - _reset is not None and tensordict_reset.get(self.done_key)[_reset].any() - ): - raise RuntimeError( - f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." - ) if tensordict is not None: tensordict.update(tensordict_reset) else: tensordict = tensordict_reset + tensordict.exclude(*self.reset_keys, inplace=True) return tensordict def numel(self) -> int: @@ -1014,7 +1578,7 @@ def set_state(self): def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: if ( - self.batch_locked or self.batch_size != torch.Size([]) + self.batch_locked or self.batch_size != () ) and tensordict.batch_size != self.batch_size: raise RuntimeError( f"Expected a tensordict with shape==env.shape, " @@ -1044,7 +1608,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): f"Non batch-locked environment require the env batch-size to be either empty or to" f" match the tensordict one." ) - r = self.input_spec["_action_spec"].rand(shape) + r = self.input_spec["full_action_spec"].rand(shape) if tensordict is None: return r tensordict.update(r) @@ -1088,7 +1652,7 @@ def rollout( break_when_any_done: bool = True, return_contiguous: bool = True, tensordict: Optional[TensorDictBase] = None, - ) -> TensorDictBase: + ): """Executes a rollout in the environment. The function will stop as soon as one of the contained environments @@ -1236,41 +1800,113 @@ def policy(td): if auto_cast_to_device: tensordict = tensordict.to(env_device, non_blocking=True) tensordict = self.step(tensordict) - tensordicts.append(tensordict.clone(False)) - done = tensordict.get(("next", self.done_key)) - truncated = tensordict.get( - ("next", "truncated"), - default=torch.zeros((), device=done.device, dtype=torch.bool), - ) - done = done | truncated - if (break_when_any_done and done.any()) or i == max_steps - 1: + + if i == max_steps - 1: + # we don't truncated as one could potentially continue the run break tensordict = step_mdp( tensordict, keep_other=True, - exclude_action=True, + exclude_action=False, exclude_reward=True, - reward_key=self.reward_key, - action_key=self.action_key, - done_key=self.done_key, + reward_keys=self.reward_keys, + action_keys=self.action_keys, + done_keys=self.done_keys, ) - if not break_when_any_done and done.any(): - _reset = done.clone() - tensordict.set("_reset", _reset) - self.reset(tensordict) + # done and truncated are in done_keys + # We read if any key is done. + any_done = terminated_or_truncated( + tensordict, + full_done_spec=self.output_spec["full_done_spec"], + key=None if break_when_any_done else "_reset", + ) + if break_when_any_done and any_done: + break + if not break_when_any_done and any_done: + tensordict = self.reset(tensordict) if callback is not None: callback(self, tensordict) batch_size = self.batch_size if tensordict is None else tensordict.batch_size - out_td = torch.stack(tensordicts, len(batch_size)) if return_contiguous: out_td = out_td.contiguous() out_td.refine_names(..., "time") return out_td + @property + def reset_keys(self) -> List[NestedKey]: + """Returns a list of reset keys. + + Reset keys are keys that indicate partial reset, in batched, multitask or multiagent + settings. They are structured as ``(*prefix, "_reset")`` where ``prefix`` is + a (possibly empty) tuple of strings pointing to a tensordict location + where a done state can be found. + + The value of reset_keys is cached. 
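For instance (a sketch; the actual keys depend on how the done entries are nested):

>>> env.reset_keys            # flat done spec, e.g. a simple Gym env
['_reset']
>>> # with a nested done entry such as ("agent", "done"), the matching
>>> # reset key would be ("agent", "_reset")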
+ """ + reset_keys = self.__dict__.get("_reset_keys", None) + if reset_keys is not None: + return reset_keys + prefixes = set() + reset_keys = [] + + def prefix(key: NestedKey): + if isinstance(key, str): + return None + return key[:-1] + + def combine(prefix_key: tuple | None, key: str): + if prefix_key is None: + return key + return (*prefix_key, key) + + for done_key in self.done_keys: + prefix_key = prefix(done_key) + if prefix_key in prefixes: + continue + prefixes.add(prefix_key) + reset_keys.append(combine(prefix_key, "_reset")) + self.__dict__["_reset_keys"] = reset_keys + return reset_keys + + @property + def done_keys_groups(self): + """A list of done keys, grouped as the reset keys. + + This is a list of lists. The outer list has the length of reset keys, the + inner lists contain the done keys (eg, done and truncated) that can + be read to determine a reset when it is absent. + + The value of ``done_keys_groups`` is cached. + + """ + done_keys_sorted = self.__dict__.get("_done_keys_groups", None) + if done_keys_sorted is not None: + return done_keys_sorted + # done keys, sorted as reset keys + reset_keys = self.reset_keys + done_keys = [[] for _ in range(len(reset_keys))] + reset_keys_iter = iter(reset_keys) + done_keys_iter = iter(done_keys) + try: + curr_reset_key = next(reset_keys_iter) + curr_done_key = next(done_keys_iter) + except StopIteration: + return done_keys + + for done_key in self.done_keys: + while type(done_key) != type(curr_reset_key) or ( + isinstance(done_key, tuple) and done_key[:-1] != curr_reset_key[:-1] + ): # if they are string, they are at the same level + curr_reset_key = next(reset_keys_iter) + curr_done_key = next(done_keys_iter) + curr_done_key.append(done_key) + self.__dict__["_done_keys_groups"] = done_keys + return done_keys + def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]: for key in tensordict.keys(): if key.rfind("observation") >= 0: @@ -1305,13 +1941,11 @@ def fake_tensordict(self) -> TensorDictBase: """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout.""" state_spec = self.state_spec observation_spec = self.observation_spec - action_spec = self.input_spec["_action_spec"] + action_spec = self.input_spec["full_action_spec"] # instantiates reward_spec if needed _ = self.reward_spec - reward_spec = self.output_spec["_reward_spec"] - # instantiates done_spec if needed - _ = self.done_spec - done_spec = self.output_spec["_done_spec"] + reward_spec = self.output_spec["full_reward_spec"] + full_done_spec = self.output_spec["full_done_spec"] fake_obs = observation_spec.zero() @@ -1324,20 +1958,23 @@ def fake_tensordict(self) -> TensorDictBase: fake_in_out = fake_input.update(fake_obs) fake_reward = reward_spec.zero() - fake_done = done_spec.zero() + fake_done = full_done_spec.zero() next_output = fake_obs.clone() next_output.update(fake_reward) next_output.update(fake_done) fake_in_out.update(fake_done.clone()) + if "next" not in fake_in_out.keys(): + fake_in_out.set("next", next_output) + else: + fake_in_out.get("next").update(next_output) - fake_td = fake_in_out.set("next", next_output) - fake_td.batch_size = self.batch_size - fake_td = fake_td.to(self.device) - return fake_td + fake_in_out.batch_size = self.batch_size + fake_in_out = fake_in_out.to(self.device) + return fake_in_out -class _EnvWrapper(EnvBase, metaclass=abc.ABCMeta): +class _EnvWrapper(EnvBase): """Abstract environment wrapper class. 
Unlike EnvBase, _EnvWrapper comes with a :obj:`_build_env` private method that will be called upon instantiation. @@ -1360,14 +1997,18 @@ def __init__( self, *args, dtype: Optional[np.dtype] = None, - device: DEVICE_TYPING = "cpu", + device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, + allow_done_after_reset: bool = False, **kwargs, ): + if device is None: + device = torch.device("cpu") super().__init__( device=device, dtype=dtype, batch_size=batch_size, + allow_done_after_reset=allow_done_after_reset, ) if len(args): raise ValueError( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index fa724c0ca10..79dc8c4ab64 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -115,11 +115,11 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ - _info_dict_reader: BaseInfoDictReader + _info_dict_reader: List[BaseInfoDictReader] @classmethod def __new__(cls, *args, **kwargs): - cls._info_dict_reader = None + cls._info_dict_reader = [] return super().__new__(cls, *args, _batch_locked=True, **kwargs) def read_action(self, action): @@ -133,28 +133,67 @@ def read_action(self, action): """ return self.action_spec.to_numpy(action, safe=False) - def read_done(self, done): + def read_done( + self, + terminated: bool | None = None, + truncated: bool | None = None, + done: bool | None = None, + ) -> Tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]: """Done state reader. - Reads a done state and returns a tuple containing: - - a done state to be set in the environment - - a boolean value indicating whether the frame_skip loop should be broken + In torchrl, a `"done"` signal means that a trajectory has reach its end, + either because it has been interrupted or because it is terminated. + Truncated means the episode has been interrupted early. + Terminated means the task is finished, the episode is completed. Args: - done (np.ndarray, boolean or other format): done state obtained from the environment + terminated (np.ndarray, boolean or other format): completion state + obtained from the environment. + ``"terminated"`` equates to ``"termination"`` in gymnasium: + the signal that the environment has reached the end of the + episode, any data coming after this should be considered as nonsensical. + Defaults to ``None``. + truncated (bool or None): early truncation signal. + Defaults to ``None``. + done (bool or None): end-of-trajectory signal. + This should be the fallback value of envs which do not specify + if the ``"done"`` entry points to a ``"terminated"`` or + ``"truncated"``. + Defaults to ``None``. + + Returns: a tuple with 4 boolean / tensor values, + - a terminated state, + - a truncated state, + - a done state, + - a boolean value indicating whether the frame_skip loop should be broken. """ - return done, done + if truncated is not None and done is None: + done = truncated | terminated + elif truncated is None and done is None: + done = terminated + do_break = done.any() if not isinstance(done, bool) else done + if isinstance(done, bool): + done = [done] + if terminated is not None: + terminated = [terminated] + if truncated is not None: + truncated = [truncated] + return ( + terminated, + truncated, + done, + do_break.any() if not isinstance(do_break, bool) else do_break, + ) - def read_reward(self, total_reward, step_reward): - """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two. 
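To illustrate the ``read_done`` contract above with plain booleans (``env`` is assumed to be any ``GymLikeEnv`` instance; containers and shapes are illustrative):

>>> env.read_done(terminated=True, truncated=False)
([True], [False], [True], True)
>>> # done is inferred as terminated | truncated, and the last element tells
>>> # the frame_skip loop whether it should break.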
+ def read_reward(self, reward): + """Reads the reward and maps it to the reward space. Args: - total_reward (torch.Tensor or TensorDict): total reward so far in the step - step_reward (reward in the format provided by the inner env): reward of this particular step + reward (torch.Tensor or TensorDict): reward to be mapped. """ - return total_reward + self.reward_spec.encode(step_reward, ignore_device=True) + return self.reward_spec.encode(reward, ignore_device=True) def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] @@ -166,7 +205,6 @@ def read_obs( """ if isinstance(observations, dict): - observations = {key: value for key, value in observations.items()} if "state" in observations and "observation" not in observations: # we rename "state" in "observation" as "observation" is the conventional name # for single observation in torchrl. @@ -184,102 +222,125 @@ def read_obs( return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - action = tensordict.get("action") + action = tensordict.get(self.action_key) action_np = self.read_action(action) reward = 0 for _ in range(self.wrapper_frame_skip): - obs, _reward, done, *info = self._output_transform( + obs, _reward, terminated, truncated, done, info = self._output_transform( self._env.step(action_np) ) if isinstance(obs, list) and len(obs) == 1: # Until gym 0.25.2 we had rendered frames returned in lists of length 1 obs = obs[0] - if len(info) == 2: - # gym 0.26 - truncation, info = info - done = done | truncation - elif len(info) == 1: - info = info[0] - elif len(info) == 0: - info = None - else: - raise ValueError( - "the environment output is expected to be either" - "obs, reward, done, truncation, info (gym >= 0.26) or " - f"obs, reward, done, info. Got info with types = ({[type(x) for x in info]})" - ) if _reward is None: _reward = self.reward_spec.zero() - reward = self.read_reward(reward, _reward) - - if isinstance(done, bool) or ( - isinstance(done, np.ndarray) and not len(done) - ): - done = torch.tensor([done]) + reward = reward + _reward - done, do_break = self.read_done(done) + terminated, truncated, done, do_break = self.read_done( + terminated=terminated, truncated=truncated, done=done + ) if do_break: break + reward = self.read_reward(reward) obs_dict = self.read_obs(obs) if reward is None: reward = torch.tensor(np.nan).expand(self.reward_spec.shape) - # reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) - # done = self._to_tensor(done, dtype=torch.bool) - obs_dict["reward"] = reward - obs_dict["done"] = done - obs_dict = {("next", key): val for key, val in obs_dict.items()} - tensordict_out = TensorDict( - obs_dict, batch_size=tensordict.batch_size, device=self.device - ) + obs_dict[self.reward_key] = reward + + # if truncated/terminated is not in the keys, we just don't pass it even if it + # is defined. + if terminated is None: + terminated = done + if truncated is not None and "truncated" in self.done_keys: + obs_dict["truncated"] = truncated + obs_dict["done"] = done + obs_dict["terminated"] = terminated - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out.get("next")) + tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size) + if self.info_dict_reader and info is not None: + if not isinstance(info, dict): + warnings.warn( + f"Expected info to be a dictionary but got a {type(info)} with values {str(info)[:100]}." 
+ ) + else: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: - reset_data = self._env.reset(**kwargs) - if not isinstance(reset_data, tuple): - reset_data = (reset_data,) - obs, *other = self._output_transform(reset_data) - info = None - if len(other) == 1: - info = other[0] + obs, info = self._reset_output_transform(self._env.reset(**kwargs)) + + source = self.read_obs(obs) tensordict_out = TensorDict( - source=self.read_obs(obs), + source=source, batch_size=self.batch_size, - device=self.device, ) - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out) - elif info is None and self.info_dict_reader is not None: + if self.info_dict_reader and info is not None: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out + elif info is None and self.info_dict_reader: # populate the reset with the items we have not seen from info - for key, item in self.observation_spec.items(): - if key not in tensordict_out.keys(): + for key, item in self.observation_spec.items(True, True): + if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - - tensordict_out.setdefault( - "done", - self.done_spec.zero(), - ) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out - def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: - """To be overwritten when step_outputs differ from Tuple[Observation: Union[np.ndarray, dict], reward: Number, done:Bool].""" - if not isinstance(step_outputs_tuple, tuple): - raise TypeError( - f"Expected step_outputs_tuple type to be Tuple but got {type(step_outputs_tuple)}" - ) - return step_outputs_tuple + @abc.abstractmethod + def _output_transform( + self, step_outputs_tuple: Tuple + ) -> Tuple[ + Any, + float | np.ndarray, + bool | np.ndarray | None, + bool | np.ndarray | None, + bool | np.ndarray | None, + dict, + ]: + """A method to read the output of the env step. + + Must return a tuple: (obs, reward, terminated, truncated, done, info). + If only one end-of-trajectory is passed, it is interpreted as ``"truncated"``. + An attempt to retrieve ``"truncated"`` from the info dict is also undertaken. + If 2 are passed (like in gymnasium), we interpret them as ``"terminated", + "truncated"`` (``"truncated"`` meaning that the trajectory has been + interrupted early), and ``"done"`` is the union of the two, + ie. the unspecified end-of-trajectory signal. + + These three concepts have different usage: + + - ``"terminated"`` indicated the final stage of a Markov Decision + Process. It means that one should not pay attention to the + upcoming observations (eg., in value functions) as they should be + regarded as not valid. + - ``"truncated"`` means that the environment has reached a stage where + we decided to stop the collection for some reason but the next + observation should not be discarded. If it were not for this + arbitrary decision, the collection could have proceeded further. + - ``"done"`` is either one or the other. It is to be interpreted as + "a reset should be called before the next step is undertaken". + + """ + ... 
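As an illustration of the contract described above, a wrapper around a gymnasium-style environment could implement it roughly as follows (a sketch, not the library's actual implementation):

>>> def _output_transform(self, step_outputs_tuple):
...     # gymnasium returns (obs, reward, terminated, truncated, info)
...     obs, reward, terminated, truncated, info = step_outputs_tuple
...     done = terminated | truncated  # "done" is the union of the two signals
...     return obs, reward, terminated, truncated, done, info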
+ + @abc.abstractmethod + def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + ... def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeEnv: """Sets an info_dict_reader function. @@ -306,9 +367,12 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE >>> assert "my_info_key" in tensordict.keys() """ - self.info_dict_reader = info_dict_reader - for info_key, spec in info_dict_reader.info_spec.items(): - self.observation_spec[info_key] = spec.to(self.device) + self.info_dict_reader.append(info_dict_reader) + if isinstance(info_dict_reader, BaseInfoDictReader): + # if we have a BaseInfoDictReader, we know what the specs will be + # In other cases (eg, RoboHive) we will need to figure it out empirically. + for info_key, spec in info_dict_reader.info_spec.items(): + self.observation_spec[info_key] = spec.to(self.device) return self def __repr__(self) -> str: @@ -322,4 +386,10 @@ def info_dict_reader(self): @info_dict_reader.setter def info_dict_reader(self, value: callable): - self._info_dict_reader = value + warnings.warn( + f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " + f"This method will append a reader to the list of existing readers (if any). " + f"Setting info_dict_reader directly will be soon deprecated.", + category=DeprecationWarning, + ) + self._info_dict_reader.append(value) diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 7bec24cb17b..9121ea4c677 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -2,3 +2,23 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .brax import BraxEnv, BraxWrapper +from .dm_control import DMControlEnv, DMControlWrapper +from .envpool import MultiThreadedEnv, MultiThreadedEnvWrapper +from .gym import ( + gym_backend, + GymEnv, + GymWrapper, + MOGymEnv, + MOGymWrapper, + set_gym_backend, +) +from .habitat import HabitatEnv +from .isaacgym import IsaacGymEnv, IsaacGymWrapper +from .jumanji import JumanjiEnv, JumanjiWrapper +from .openml import OpenMLEnv +from .pettingzoo import PettingZooEnv, PettingZooWrapper +from .robohive import RoboHiveEnv +from .smacv2 import SMACv2Env, SMACv2Wrapper +from .vmas import VmasEnv, VmasWrapper diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 06c5e4db28a..f9a9d555c29 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
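With the re-exports added to ``torchrl/envs/libs/__init__.py`` above, the wrappers can be imported directly from the subpackage (a sketch; each wrapper still requires its optional backend to be installed before it can be instantiated):

>>> from torchrl.envs.libs import BraxEnv, GymEnv, VmasEnv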
+import importlib.util from typing import Dict, Optional, Union @@ -14,31 +15,26 @@ UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper - -try: - import brax - import brax.envs - import jax - from torchrl.envs.libs.jax_utils import ( - _extract_spec, - _ndarray_to_tensor, - _object_to_tensordict, - _tensor_to_ndarray, - _tensordict_to_object, - _tree_flatten, - _tree_reshape, - ) - - _has_brax = True - IMPORT_ERR = "" -except ImportError as err: - _has_brax = False - IMPORT_ERR = str(err) +from torchrl.envs.utils import _classproperty + +_has_brax = importlib.util.find_spec("brax") is not None +from torchrl.envs.libs.jax_utils import ( + _extract_spec, + _ndarray_to_tensor, + _object_to_tensordict, + _tensor_to_ndarray, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, +) def _get_envs(): if not _has_brax: - return [] + raise ImportError("BRAX is not installed in your virtual environment.") + + import brax.envs + return list(brax.envs._envs.keys()) @@ -74,13 +70,39 @@ class BraxWrapper(_EnvWrapper): """ git_url = "https://github.com/google/brax" - available_envs = _get_envs() + + @_classproperty + def available_envs(cls): + if not _has_brax: + return + yield from _get_envs() + libname = "brax" - @property - def lib(self): + _lib = None + _jax = None + + @_classproperty + def lib(cls): + if cls._lib is not None: + return cls._lib + + import brax + import brax.envs + + cls._lib = brax return brax + @_classproperty + def jax(cls): + if cls._jax is not None: + return cls._jax + + import jax + + cls._jax = jax + return jax + def __init__(self, env=None, categorical_action_encoding=False, **kwargs): if env is not None: kwargs["env"] = env @@ -89,6 +111,8 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): super().__init__(**kwargs) def _check_kwargs(self, kwargs: Dict): + brax = self.lib + if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -114,7 +138,9 @@ def _build_env( raise NotImplementedError("TODO") return env - def _make_state_spec(self, env: "brax.envs.env.Env"): + def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 + jax = self.jax + key = jax.random.PRNGKey(0) state = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) @@ -123,8 +149,8 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.action_spec = BoundedTensorSpec( - minimum=-1, - maximum=1, + low=-1, + high=1, shape=( *self.batch_size, env.action_size, @@ -154,6 +180,8 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.observation_spec["state"] = state_spec.clone() def _make_state_example(self): + jax = self.jax + key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) @@ -161,17 +189,20 @@ def _make_state_example(self): return state def _init_env(self) -> Optional[int]: + jax = self.jax self._key = None self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset)) self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step)) self._state_example = self._make_state_example() def _set_seed(self, seed: int): + jax = self.jax if seed is None: raise Exception("Brax requires an integer seed.") self._key = jax.random.PRNGKey(seed) def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: + jax = self.jax # generate random keys self._key, *keys = 
jax.random.split(self._key, 1 + self.numel()) @@ -192,6 +223,7 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: "observation": state.get("obs"), # "reward": reward, "done": done, + "terminated": done.clone(), "state": state, }, batch_size=self.batch_size, @@ -227,6 +259,7 @@ def _step_without_grad(self, tensordict: TensorDictBase): "observation": next_state.get("obs"), "reward": reward, "done": done, + "terminated": done.clone(), "state": next_state, }, batch_size=self.batch_size, @@ -265,6 +298,7 @@ def _step_with_grad(self, tensordict: TensorDictBase): "observation": next_obs, "reward": next_reward, "done": next_done, + "terminated": next_done, "state": next_state, }, batch_size=self.batch_size, @@ -282,7 +316,6 @@ def _step( out = self._step_with_grad(tensordict) else: out = self._step_without_grad(tensordict) - out = out.select().set("next", out) return out @@ -305,13 +338,13 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "brax.envs.env.Env": + ) -> "brax.envs.env.Env": # noqa: F821 if not _has_brax: - raise RuntimeError( + raise ImportError( f"brax not found, unable to create {env_name}. " f"Consider downloading and installing brax from" f" {self.git_url}" - ) from IMPORT_ERR + ) from_pixels = kwargs.pop("from_pixels", False) pixels_only = kwargs.pop("pixels_only", True) requires_grad = kwargs.pop("requires_grad", False) @@ -341,6 +374,7 @@ def __repr__(self) -> str: class _BraxEnvStep(torch.autograd.Function): @staticmethod def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values): + import jax # convert tensors to ndarrays state_obj = _tensordict_to_object(state_td, env._state_example) @@ -376,15 +410,6 @@ def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values): @staticmethod def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): - # build gradient tensordict with zeros in fields with no grad - # if grad_next_reward is None: - # raise RuntimeError("grad_next_reward") - # grad_next_reward = torch.zeros((*ctx.env.batch_size, 1), device=ctx.env.device) - # if grad_next_obs is None: - # raise RuntimeError("grad_next_obs") - # if any(val is None for val in grad_next_qp_values): - # raise RuntimeError("grad_next_qp_values") - pipeline_state = dict( zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values) ) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 56bf48fa0ac..89b402bd904 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -5,24 +5,28 @@ from __future__ import annotations import collections + +import importlib import os from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch +from torchrl._utils import VERBOSE + from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, + DiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) -from ..._utils import VERBOSE - -from ...data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict -from ..gym_like import GymLikeEnv +from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict +from torchrl.envs.gym_like import GymLikeEnv +from torchrl.envs.utils import _classproperty if torch.cuda.device_count() > 1: n = torch.cuda.device_count() - 1 @@ -30,20 +34,7 @@ if VERBOSE: print("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"]) -try: - - import dm_control - import dm_env - from dm_control import suite - from dm_control.suite.wrappers import pixels - - _has_dmc = True - -except 
ImportError as err: - _has_dmc = False - IMPORT_ERR = err -else: - IMPORT_ERR = None +_has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None __all__ = ["DMControlEnv", "DMControlWrapper"] @@ -53,6 +44,8 @@ def _dmcontrol_to_torchrl_spec_transform( dtype: Optional[torch.dtype] = None, device: DEVICE_TYPING = None, ) -> TensorSpec: + import dm_env + if isinstance(spec, collections.OrderedDict): spec = { k: _dmcontrol_to_torchrl_spec_transform(item, device=device) @@ -67,8 +60,8 @@ def _dmcontrol_to_torchrl_spec_transform( shape = torch.Size([1]) return BoundedTensorSpec( shape=shape, - minimum=spec.minimum, - maximum=spec.maximum, + low=spec.minimum, + high=spec.maximum, dtype=dtype, device=device, ) @@ -90,8 +83,10 @@ def _dmcontrol_to_torchrl_spec_transform( def _get_envs(to_dict: bool = True) -> Dict[str, Any]: - if not _has_dmc: - return {} + if not _has_dm_control: + raise ImportError("Cannot find dm_control in virtual environment.") + from dm_control import suite + if not to_dict: return tuple(suite.BENCHMARKING) + tuple(suite.EXTRA) d = {} @@ -101,7 +96,7 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]: for tup in suite.EXTRA: env_name = tup[0] d.setdefault(env_name, []).append(tup[1]) - return d + return d.items() def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: @@ -130,7 +125,18 @@ class DMControlWrapper(GymLikeEnv): git_url = "https://github.com/deepmind/dm_control" libname = "dm_control" - available_envs = _get_envs() + + @_classproperty + def available_envs(cls): + if not _has_dm_control: + return + yield from _get_envs() + + @property + def lib(self): + import dm_control + + return dm_control def __init__(self, env=None, **kwargs): if env is not None: @@ -151,6 +157,8 @@ def _build_env( self.pixels_only = pixels_only if from_pixels: + from dm_control.suite.wrappers import pixels + self._set_egl_device(self.device) self.render_kwargs = {"camera_id": camera_id} if render_kwargs is not None: @@ -174,12 +182,23 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 reward_spec.shape = torch.Size([1]) self.reward_spec = reward_spec # populate default done spec - _ = self.done_spec + done_spec = DiscreteTensorSpec( + n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device + ) + self.done_spec = CompositeSpec( + done=done_spec.clone(), + truncated=done_spec.clone(), + terminated=done_spec.clone(), + device=self.device, + ) self.action_spec = _dmcontrol_to_torchrl_spec_transform( self._env.action_spec(), device=self.device ) def _check_kwargs(self, kwargs: Dict): + dm_control = self.lib + from dm_control.suite.wrappers import pixels + if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -207,6 +226,8 @@ def _init_env(self, seed: Optional[int] = None) -> Optional[int]: return seed def _set_seed(self, _seed: Optional[int]) -> Optional[int]: + from dm_control.suite.wrappers import pixels + if _seed is None: return None random_state = np.random.RandomState(_seed) @@ -223,14 +244,27 @@ def _set_seed(self, _seed: Optional[int]) -> Optional[int]: def _output_transform( self, timestep_tuple: Tuple["TimeStep"] # noqa: F821 - ) -> Tuple[np.ndarray, float, bool]: + ) -> Tuple[np.ndarray, float, bool, bool, dict]: if type(timestep_tuple) is not tuple: timestep_tuple = (timestep_tuple,) reward = timestep_tuple[0].reward - done = False # dm_control envs are non-terminating + done = truncated = terminated = False # dm_control envs are non-terminating observation 
= timestep_tuple[0].observation - return observation, reward, done + info = {} + + return observation, reward, terminated, truncated, done, info + + def _reset_output_transform(self, reset_data): + ( + observation, + reward, + terminated, + truncated, + done, + info, + ) = self._output_transform(reset_data) + return observation, info def __repr__(self) -> str: return ( @@ -262,7 +296,7 @@ def __init__(self, env_name, task_name, **kwargs): if not _has_dmc: raise ImportError( "dm_control python package was not found. Please install this dependency." - ) from IMPORT_ERR + ) kwargs["env_name"] = env_name kwargs["task_name"] = task_name super().__init__(**kwargs) @@ -274,6 +308,8 @@ def _build_env( _seed: Optional[int] = None, **kwargs, ): + from dm_control import suite + self.env_name = env_name self.task_name = task_name @@ -285,7 +321,7 @@ def _build_env( del kwargs["pixels_only"] if not _has_dmc: - raise RuntimeError( + raise ImportError( f"dm_control not found, unable to create {env_name}:" f" {task_name}. Consider downloading and installing " f"dm_control from {self.git_url}" @@ -314,9 +350,10 @@ def _check_kwargs(self, kwargs: Dict): env_name = kwargs["env_name"] if "task_name" in kwargs: task_name = kwargs["task_name"] + available_envs = dict(self.available_envs) if ( - env_name not in self.available_envs - or task_name not in self.available_envs[env_name] + env_name not in available_envs + or task_name not in available_envs[env_name] ): raise RuntimeError( f"{env_name} with task {task_name} is unknown in {self.libname}" diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py new file mode 100644 index 00000000000..9774e0627e0 --- /dev/null +++ b/torchrl/envs/libs/envpool.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib +import logging +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch + +from tensordict import TensorDict, TensorDictBase +from torchrl.data import ( + CompositeSpec, + DiscreteTensorSpec, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import _classproperty + +_has_envpool = importlib.util.find_spec("envpool") is not None + + +class MultiThreadedEnvWrapper(_EnvWrapper): + """Wrapper for envpool-based multithreaded environments.""" + + _verbose: bool = False + + @_classproperty + def lib(cls): + import envpool + + return envpool + + def __init__( + self, + env: Optional["envpool.python.envpool.EnvPoolMixin"] = None, # noqa: F821 + **kwargs, + ): + if not _has_envpool: + raise ImportError( + "envpool python package or one of its dependencies (gym, treevalue) were not found. Please install these dependencies." + ) + if env is not None: + kwargs["env"] = env + self.num_workers = env.config["num_envs"] + # For synchronous mode batch size is equal to the number of workers + self.batch_size = torch.Size([self.num_workers]) + super().__init__(**kwargs) + + # Buffer to keep the latest observation for each worker + # It's a TensorDict when the observation consists of several variables, e.g. 
"position" and "velocity" + self.obs: Union[torch.tensor, TensorDict] = self.observation_spec.zero() + + def _check_kwargs(self, kwargs: Dict): + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + import envpool + + if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)): + raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.") + + def _build_env(self, env: "envpool.python.envpool.EnvPoolMixin"): # noqa: F821 + return env + + def _make_specs( + self, env: "envpool.python.envpool.EnvPoolMixin" # noqa: F821 + ) -> None: # noqa: F821 + from torchrl.envs.libs.gym import set_gym_backend + + with set_gym_backend("gym"): + self.action_spec = self._get_action_spec() + output_spec = self._get_output_spec() + self.observation_spec = output_spec["full_observation_spec"] + self.reward_spec = output_spec["full_reward_spec"] + self.done_spec = output_spec["full_done_spec"] + + def _init_env(self) -> Optional[int]: + pass + + def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: + if tensordict is not None: + reset_workers = tensordict.get("_reset", None) + else: + reset_workers = None + if reset_workers is not None: + reset_data = self._env.reset(np.where(reset_workers.cpu().numpy())[0]) + else: + reset_data = self._env.reset() + tensordict_out = self._transform_reset_output(reset_data, reset_workers) + self.is_closed = False + return tensordict_out + + @torch.no_grad() + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + action = tensordict.get(self.action_key) + # Action needs to be moved to CPU and converted to numpy before being passed to envpool + action = action.to(torch.device("cpu")) + step_output = self._env.step(action.numpy()) + tensordict_out = self._transform_step_output(step_output) + return tensordict_out + + def _get_action_spec(self) -> TensorSpec: + # local import to avoid importing gym in the script + from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform + + # Envpool provides Gym-compatible specs as env.spec.action_space and + # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones. 
+ + # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase + action_spec = _gym_to_torchrl_spec_transform( + self._env.spec.action_space, + device=self.device, + categorical_action_encoding=True, + ) + action_spec = self._add_shape_to_spec(action_spec) + return action_spec + + def _get_output_spec(self) -> TensorSpec: + return CompositeSpec( + full_observation_spec=self._get_observation_spec(), + full_reward_spec=self._get_reward_spec(), + full_done_spec=self._get_done_spec(), + shape=(self.num_workers,), + device=self.device, + ) + + def _get_observation_spec(self) -> TensorSpec: + # local import to avoid importing gym in the script + from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform + + # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase + observation_spec = _gym_to_torchrl_spec_transform( + self._env.spec.observation_space, + device=self.device, + categorical_action_encoding=True, + ) + observation_spec = self._add_shape_to_spec(observation_spec) + if isinstance(observation_spec, CompositeSpec): + return observation_spec + return CompositeSpec( + observation=observation_spec, + shape=(self.num_workers,), + device=self.device, + ) + + def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec: + return spec.expand((self.num_workers, *spec.shape)) + + def _get_reward_spec(self) -> TensorSpec: + return UnboundedContinuousTensorSpec( + device=self.device, + shape=self.batch_size, + ) + + def _get_done_spec(self) -> TensorSpec: + spec = DiscreteTensorSpec( + 2, + device=self.device, + shape=self.batch_size, + dtype=torch.bool, + ) + return CompositeSpec( + done=spec, + truncated=spec.clone(), + terminated=spec.clone(), + shape=self.batch_size, + device=self.device, + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(num_workers={self.num_workers}, device={self.device})" + + def _transform_reset_output( + self, + envpool_output: Tuple[ + Union["treevalue.TreeValue", np.ndarray], Any # noqa: F821 + ], + reset_workers: Optional[torch.Tensor], + ): + """Process output of envpool env.reset.""" + import treevalue + + observation, _ = envpool_output + if reset_workers is not None: + # Only specified workers were reset - need to set observation buffer values only for them + if isinstance(observation, treevalue.TreeValue): + # If observation contain several fields, it will be returned as treevalue.TreeValue. + # Convert to treevalue.FastTreeValue to allow indexing + observation = treevalue.FastTreeValue(observation) + self.obs[reset_workers] = self._treevalue_or_numpy_to_tensor_or_dict( + observation + ) + else: + # All workers were reset - rewrite the whole observation buffer + self.obs = TensorDict( + self._treevalue_or_numpy_to_tensor_or_dict(observation), + self.batch_size, + device=self.device, + ) + + obs = self.obs.clone(False) + obs.update(self.full_done_spec.zero()) + return obs + + def _transform_step_output( + self, envpool_output: Tuple[Any, Any, Any, ...] + ) -> TensorDict: + """Process output of envpool env.step.""" + out = envpool_output + if len(out) == 4: + obs, reward, done, info = out + terminated = done + truncated = info.get("TimeLimit.truncated", done * 0) + elif len(out) == 5: + obs, reward, terminated, truncated, info = out + done = terminated | truncated + else: + raise TypeError( + f"The output of step was had {len(out)} elements, but only 4 or 5 are supported." 
+            )
+        obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
+        reward_and_done = {self.reward_key: torch.tensor(reward)}
+        reward_and_done["done"] = done
+        reward_and_done["terminated"] = terminated
+        reward_and_done["truncated"] = truncated
+        obs.update(reward_and_done)
+        self.obs = tensordict_out = TensorDict(
+            obs,
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+        return tensordict_out
+
+    def _treevalue_or_numpy_to_tensor_or_dict(
+        self, x: Union["treevalue.TreeValue", np.ndarray]  # noqa: F821
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        """Converts an observation returned by EnvPool.
+
+        EnvPool's step and reset return the observation as a numpy array or a TreeValue of numpy arrays,
+        which we convert to a tensor or a dictionary of tensors. Currently only depth-1 trees are supported,
+        but this can easily be extended to arbitrary depth if necessary.
+        """
+        import treevalue
+
+        if isinstance(x, treevalue.TreeValue):
+            ret = self._treevalue_to_dict(x)
+        elif not isinstance(x, dict):
+            ret = {"observation": torch.tensor(x)}
+        else:
+            ret = x
+        return ret
+
+    def _treevalue_to_dict(
+        self, tv: "treevalue.TreeValue"  # noqa: F821
+    ) -> Dict[str, Any]:
+        """Converts a TreeValue to a dictionary.
+
+        Currently only depth-1 trees are supported, but this can easily be extended to arbitrary depth if necessary.
+        """
+        import treevalue
+
+        return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)}
+
+    def _set_seed(self, seed: Optional[int]):
+        if seed is not None:
+            print(
+                "MultiThreadedEnvWrapper._set_seed ignored, as setting the seed of an existing environment "
+                "is not supported by envpool. Please create a new environment, passing the seed to the constructor."
+            )
+
+
+class MultiThreadedEnv(MultiThreadedEnvWrapper):
+    """Multithreaded execution of environments based on EnvPool.
+
+    An alternative to ParallelEnv based on multithreading. It is faster, as it does not require new processes
+    to be spawned, but less flexible, as it only supports environments implemented in the EnvPool library.
+    Currently only the synchronous execution mode is supported, in which the batch size equals the number of
+    workers; see https://envpool.readthedocs.io/en/latest/content/python_interface.html#batch-size.
+
+    >>> env = MultiThreadedEnv(num_workers=3, env_name="Pendulum-v1")
+    >>> env.reset()
+    >>> env.rand_step()
+    >>> env.rollout(5)
+    >>> env.close()
+
+    Args:
+        num_workers: number of worker threads to create.
+        env_name: name of the environment, corresponding to task_id in EnvPool.
+        create_env_kwargs: additional arguments which will be passed to envpool.make.
+ """ + + def __init__( + self, + num_workers: int, + env_name: str, + create_env_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.env_name = env_name.replace("ALE/", "") # Naming convention of EnvPool + self.num_workers = num_workers + self.batch_size = torch.Size([num_workers]) + self.create_env_kwargs = create_env_kwargs or {} + + kwargs["num_workers"] = num_workers + kwargs["env_name"] = self.env_name + kwargs["create_env_kwargs"] = create_env_kwargs + super().__init__(**kwargs) + + def _build_env( + self, + env_name: str, + num_workers: int, + create_env_kwargs: Optional[Dict[str, Any]], + ) -> Any: + import envpool + + create_env_kwargs = create_env_kwargs or {} + env = envpool.make( + task_id=env_name, + env_type="gym", + num_envs=num_workers, + gym_reset_return_info=True, + **create_env_kwargs, + ) + return super()._build_env(env) + + def _set_seed(self, seed: Optional[int]): + """Library EnvPool only supports setting a seed by recreating the environment.""" + if seed is not None: + logging.debug("Recreating EnvPool environment to set seed.") + self.create_env_kwargs["seed"] = seed + self._env = self._build_env( + env_name=self.env_name, + num_workers=self.num_workers, + create_env_kwargs=self.create_env_kwargs, + ) + + def _check_kwargs(self, kwargs: Dict): + for arg in ["num_workers", "env_name", "create_env_kwargs"]: + if arg not in kwargs: + raise TypeError(f"Expected '{arg}' to be part of kwargs") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(env={self.env_name}, num_workers={self.num_workers}, device={self.device})" diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 3a15d865602..fbec9c4f657 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -6,15 +6,14 @@ import warnings from copy import copy from types import ModuleType -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from warnings import warn +import numpy as np import torch +from packaging import version -try: - from torch.utils._contextlib import _DecoratorContextManager -except ModuleNotFoundError: - from torchrl._utils import _DecoratorContextManager +from tensordict import TensorDictBase from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( @@ -29,10 +28,22 @@ UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict +from torchrl.envs.batched_envs import CloudpickleWrapper +from torchrl.envs.common import _EnvPostInit + +from torchrl.envs.gym_like import ( + BaseInfoDictReader, + default_info_dict_reader, + GymLikeEnv, +) -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv from torchrl.envs.utils import _classproperty +try: + from torch.utils._contextlib import _DecoratorContextManager +except ModuleNotFoundError: + from torchrl._utils import _DecoratorContextManager + DEFAULT_GYM = None IMPORT_ERROR = None # check gym presence without importing it @@ -41,6 +52,7 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None _has_mo = importlib.util.find_spec("mo_gymnasium") is not None +_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None class set_gym_backend(_DecoratorContextManager): @@ -90,12 +102,11 @@ def __init__(self, backend): self.backend = backend def _call(self): + """Sets the backend as default.""" global DEFAULT_GYM DEFAULT_GYM = self.backend - # implement_for.reset() - setters = copy(implement_for._setters) found_setter = False - for setter in setters: + for setter in copy(implement_for._setters): 
check_module = ( callable(setter.module_name) and setter.module_name.__name__ == self.backend.__name__ @@ -104,21 +115,26 @@ def _call(self): self.backend.__version__, setter.from_version, setter.to_version ) if check_module and check_version: - setter(setter.fn) + setter.module_set() found_setter = True + # we keep only the setters we need. This is safe because a copy is saved under self._setters_saved if not found_setter: raise ImportError( f"could not set anything related to gym backend " - f"{self.backend.__name__} with version={self.backend.__version__}." + f"{self.backend.__name__} with version={self.backend.__version__}. " + f"Check that the gym versions match!" ) def __enter__(self): - self._setters = copy(implement_for._setters) + # we save a complete list of setters as well as whether they should be set. + # we want the full list becasue we want to be able to nest the calls to set_gym_backend. + # we also want to keep track of which ones are set to reproduce what was set before. + self._setters_saved = copy(implement_for._implementations) self._call() def __exit__(self, exc_type, exc_val, exc_tb): - implement_for.reset(setters=self._setters) - delattr(self, "_setters") + implement_for.reset(setters_dict=self._setters_saved) + delattr(self, "_setters_saved") def clone(self): # override this method if your children class takes __init__ parameters @@ -179,18 +195,38 @@ def gym_backend(submodule=None): def _gym_to_torchrl_spec_transform( - spec, dtype=None, device="cpu", categorical_action_encoding=False + spec, + dtype=None, + device="cpu", + categorical_action_encoding=False, + remap_state_to_observation: bool = True, ) -> TensorSpec: """Maps the gym specs to the TorchRL specs. - By convention, 'state' keys of Dict specs will be renamed "observation" to match the - default TorchRL keys. + Args: + spec: the gym space to transform + dtype: a dtype to use for the spec. Defaults to`spec.dtype`. + device: the device for the spec. Defaults to "cpu". + categorical_action_encoding: whether discrete spaces should be mapped to categorical or one-hot. + Defaults to one-hot. + remap_state_to_observation: whether to rename the 'state' key of Dict specs to "observation". Default is true. 
""" - gym = gym_backend() - if isinstance(spec, gym.spaces.tuple.Tuple): - raise NotImplementedError("gym.spaces.tuple.Tuple mapping not yet implemented") - if isinstance(spec, gym.spaces.discrete.Discrete): + gym_spaces = gym_backend("spaces") + if isinstance(spec, gym_spaces.tuple.Tuple): + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + s, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for s in spec + ], + 0, + ) + if isinstance(spec, gym_spaces.discrete.Discrete): action_space_cls = ( DiscreteTensorSpec if categorical_action_encoding @@ -202,22 +238,39 @@ def _gym_to_torchrl_spec_transform( else torch.long ) return action_space_cls(spec.n, device=device, dtype=dtype) - elif isinstance(spec, gym.spaces.multi_binary.MultiBinary): + elif isinstance(spec, gym_spaces.multi_binary.MultiBinary): return BinaryDiscreteTensorSpec( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) - elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): - dtype = ( - numpy_to_torch_dtype_dict[spec.dtype] - if categorical_action_encoding - else torch.long - ) - return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) - if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + elif isinstance(spec, gym_spaces.multi_discrete.MultiDiscrete): + if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1: + dtype = ( + numpy_to_torch_dtype_dict[spec.dtype] + if categorical_action_encoding + else torch.long + ) + + return ( + MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + if categorical_action_encoding + else MultiOneHotDiscreteTensorSpec( + spec.nvec, device=device, dtype=dtype + ) + ) + + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + spec[i], + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for i in range(len(spec.nvec)) + ], + 0, ) - elif isinstance(spec, gym.spaces.Box): + elif isinstance(spec, gym_spaces.Box): shape = spec.shape if not len(shape): shape = torch.Size([1]) @@ -241,7 +294,11 @@ def _gym_to_torchrl_spec_transform( spec_out = {} for k in spec.keys(): key = k - if k == "state" and "observation" not in spec.keys(): + if ( + remap_state_to_observation + and k == "state" + and "observation" not in spec.keys() + ): # we rename "state" in "observation" as "observation" is the conventional name # for single observation in torchrl. 
# naming it 'state' will result in envs that have a different name for the state vector @@ -251,13 +308,15 @@ def _gym_to_torchrl_spec_transform( spec[k], device=device, categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, ) return CompositeSpec(**spec_out) - elif isinstance(spec, gym.spaces.dict.Dict): + elif isinstance(spec, gym_spaces.dict.Dict): return _gym_to_torchrl_spec_transform( spec.spaces, device=device, categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, ) else: raise NotImplementedError( @@ -266,6 +325,8 @@ def _gym_to_torchrl_spec_transform( def _get_envs(to_dict=False) -> List: + if not _has_gym: + raise ImportError("Gym(nasium) could not be found in your virtual environment.") envs = _get_gym_envs() envs = list(envs) envs = sorted(envs) @@ -325,7 +386,52 @@ class PixelObservationWrapper: return False -class GymWrapper(GymLikeEnv): +class _AsyncMeta(_EnvPostInit): + def __call__(cls, *args, **kwargs): + instance: GymWrapper = super().__call__(*args, **kwargs) + + # before gym 0.22, there was no final_observation + if instance._is_batched: + gym_backend = instance.get_library_name(instance._env) + from torchrl.envs.transforms.transforms import ( + TransformedEnv, + VecGymEnvTransform, + ) + + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + + if isinstance(instance._env, VecEnv): + backend = "sb3" + else: + backend = "gym" + else: + backend = "gym" + + # we need 3 checks: the backend is not sb3 (if so, gymnasium is used), + # it is gym and not gymnasium and the version is before 0.22.0 + add_info_dict = True + if backend == "gym" and gym_backend == "gym": # check gym against gymnasium + import gym + + if version.parse(gym.__version__) < version.parse("0.22.0"): + warn( + "A batched gym environment is being wrapped in a GymWrapper with gym version < 0.22. " + "This implies that the next-observation is wrongly tracked (as the batched environment auto-resets " + "and discards the true next observation to return the result of the step). " + "This isn't compatible with TorchRL API and should be used with caution.", + category=UserWarning, + ) + add_info_dict = False + if add_info_dict: + instance.set_info_dict_reader( + terminal_obs_reader(instance.observation_spec, backend=backend) + ) + return TransformedEnv(instance, VecGymEnvTransform()) + return instance + + +class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): """OpenAI Gym environment wrapper. Examples: @@ -341,23 +447,35 @@ class GymWrapper(GymLikeEnv): libname = "gym" @staticmethod - def get_library_name(env): - # try gym + def get_library_name(env) -> str: + """Given a gym environment, returns the backend name (either gym or gymnasium). + + This can be used to set the appropriate backend when needed: + + Examples: + >>> env = gymnasium.make("Pendulum-v1") + >>> with set_gym_backend(env): + ... env = GymWrapper(env) + + :class:`~GymWrapper` and similar use this method to set their method + to the right backend during instantiation. + + """ try: import gym if isinstance(env.action_space, gym.spaces.space.Space): - return gym + return "gym" except ImportError: pass try: import gymnasium if isinstance(env.action_space, gymnasium.spaces.space.Space): - return gymnasium + return "gymnasium" except ImportError: pass - raise RuntimeError( + raise ImportError( f"Could not find the library of env {env}. Please file an issue on torchrl github repo." 
) @@ -367,10 +485,64 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): self._seed_calls_reset = None self._categorical_action_encoding = categorical_action_encoding if "env" in kwargs: - with set_gym_backend(self.get_library_name(kwargs["env"])): + if "EnvCompatibility" in str( + kwargs["env"] + ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env + raise ValueError( + "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " + "If this feature is needed, detail your use case in an issue of " + "https://github.com/pytorch/rl/issues." + ) + libname = self.get_library_name(kwargs["env"]) + with set_gym_backend(libname): super().__init__(**kwargs) else: super().__init__(**kwargs) + self._post_init() + + def _post_init(self): + # writes the functions that are gym-version specific to the instance + # once and for all. This is aimed at avoiding the need of decorating code + # with set_gym_backend + allowing for parallel execution (which would + # be troublesome when both an old version of gym and recent gymnasium + # are present within the same virtual env). + # + # These calls seemingly do nothing but they actually get rid of the @implement_for decorator. + # We execute them within the set_gym_backend context manager to make sure we get + # the right implementation. + # + # This method is executed by the metaclass of GymWrapper. + with set_gym_backend(self.get_library_name(self._env)): + self._reset_output_transform = self._reset_output_transform + self._output_transform = self._output_transform + + @property + def _is_batched(self): + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + + tuple_of_classes = (VecEnv,) + else: + tuple_of_classes = () + return isinstance( + self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,) + ) + + @implement_for("gym", None, "0.27") + def _get_batch_size(self, env): + if hasattr(env, "num_envs"): + batch_size = torch.Size([env.num_envs, *self.batch_size]) + else: + batch_size = self.batch_size + return batch_size + + @implement_for("gymnasium", "0.27", None) # gymnasium wants the unwrapped env + def _get_batch_size(self, env): # noqa: F811 + if hasattr(env, "num_envs"): + batch_size = torch.Size([env.unwrapped.num_envs, *self.batch_size]) + else: + batch_size = self.batch_size + return batch_size def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: @@ -385,6 +557,8 @@ def _build_env( from_pixels: bool = False, pixels_only: bool = False, ) -> "gym.core.Env": # noqa: F821 + self.batch_size = self._get_batch_size(env) + env_from_pixels = _is_from_pixels(env) from_pixels = from_pixels or env_from_pixels self.from_pixels = from_pixels @@ -404,6 +578,16 @@ def _build_env( env = self._build_gym_env(env, pixels_only) return env + def read_action(self, action): + action = super().read_action(action) + if ( + isinstance(self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)) + and action.size == 1 + ): + # some envs require an integer for indexing + action = int(action) + return action + @implement_for("gym", None, "0.19.0") def _build_gym_env(self, env, pixels_only): # noqa: F811 from .utils import GymPixelObservationWrapper as PixelObservationWrapper @@ -464,8 +648,10 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811 return LegacyPixelObservationWrapper(env, pixels_only=pixels_only) @_classproperty - def available_envs(cls) -> List[str]: - return _get_envs() + def available_envs(cls): + if not _has_gym: + return + 
yield from _get_envs() @property def lib(self) -> ModuleType: @@ -482,7 +668,12 @@ def _set_seed(self, seed: int) -> int: # noqa: F811 return seed - @implement_for("gym", None, "0.19.0") + @implement_for("gym", None, "0.15.0") + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + self._seed_calls_reset = False + self._env.seed(seed) + + @implement_for("gym", "0.15.0", "0.19.0") def _set_seed_initial(self, seed: int) -> None: # noqa: F811 self._seed_calls_reset = False self._env.seed(seed=seed) @@ -513,8 +704,8 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811 self._seed_calls_reset = False self._env.seed(seed=seed) - def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 - self.action_spec = _gym_to_torchrl_spec_transform( + def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 + action_spec = _gym_to_torchrl_spec_transform( env.action_space, device=self.device, categorical_action_encoding=self._categorical_action_encoding, @@ -526,21 +717,174 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 ) if not isinstance(observation_spec, CompositeSpec): if self.from_pixels: - observation_spec = CompositeSpec(pixels=observation_spec) + observation_spec = CompositeSpec( + pixels=observation_spec, shape=self.batch_size + ) else: - observation_spec = CompositeSpec(observation=observation_spec) - self.observation_spec = observation_spec + observation_spec = CompositeSpec( + observation=observation_spec, shape=self.batch_size + ) + elif observation_spec.shape[: len(self.batch_size)] != self.batch_size: + observation_spec.shape = self.batch_size + if hasattr(env, "reward_space") and env.reward_space is not None: - self.reward_spec = _gym_to_torchrl_spec_transform( + reward_spec = _gym_to_torchrl_spec_transform( env.reward_space, device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) else: - self.reward_spec = UnboundedContinuousTensorSpec( + reward_spec = UnboundedContinuousTensorSpec( shape=[1], device=self.device, ) + if batch_size is not None: + action_spec = action_spec.expand(*batch_size, *action_spec.shape) + reward_spec = reward_spec.expand(*batch_size, *reward_spec.shape) + observation_spec = observation_spec.expand( + *batch_size, *observation_spec.shape + ) + self.done_spec = self._make_done_spec() + self.action_spec = action_spec + if reward_spec.shape[: len(self.batch_size)] != self.batch_size: + self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape) + else: + self.reward_spec = reward_spec + self.observation_spec = observation_spec + + @implement_for("gym", None, "0.26") + def _make_done_spec(self): # noqa: F811 + return CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "terminated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "truncated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + }, + shape=self.batch_size, + ) + + @implement_for("gym", "0.26", None) + def _make_done_spec(self): # noqa: F811 + return CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "terminated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "truncated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + }, + shape=self.batch_size, + ) + + 
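The same-named overloads above and below (``_make_done_spec``, ``_reset_output_transform``, ``_output_transform``) rely on ``@implement_for`` to resolve, at call time, the variant matching the installed gym/gymnasium version. A rough sketch of that pattern, with purely illustrative class and method names:

from torchrl._utils import implement_for

class _VersionedStepReader:
    @implement_for("gym", None, "0.26")
    def read_step(self, out):  # noqa: F811
        # old gym API: step returns a 4-tuple
        obs, reward, done, info = out
        return obs, reward, done, info

    @implement_for("gym", "0.26", None)
    def read_step(self, out):  # noqa: F811
        # newer gym API: step returns a 5-tuple with terminated/truncated
        obs, reward, terminated, truncated, info = out
        return obs, reward, terminated | truncated, info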
@implement_for("gymnasium", "0.27", None) + def _make_done_spec(self): # noqa: F811 + return CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "terminated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + "truncated": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) + ), + }, + shape=self.batch_size, + ) + + @implement_for("gym", None, "0.26") + def _reset_output_transform(self, reset_data): # noqa: F811 + return reset_data, None + + @implement_for("gym", "0.26", None) + def _reset_output_transform(self, reset_data): # noqa: F811 + return reset_data + + @implement_for("gymnasium", "0.27", None) + def _reset_output_transform(self, reset_data): # noqa: F811 + return reset_data + + @implement_for("gym", None, "0.24") + def _output_transform(self, step_outputs_tuple): # noqa: F811 + observations, reward, done, info = step_outputs_tuple + if self._is_batched: + # info needs to be flipped + info = _flip_info_tuple(info) + # The variable naming follows torchrl's convention here. + # A done is interpreted the union of terminated and truncated. + # (as in earlier versions of gym). + truncated = info.pop("TimeLimit.truncated", False) + if not isinstance(done, bool) and isinstance(truncated, bool): + # if bool is an array, make truncated an array + truncated = [truncated] * len(done) + truncated = np.array(truncated) + elif not isinstance(truncated, bool): + # make sure it's a boolean np.array + truncated = np.array(truncated, dtype=np.dtype("bool")) + terminated = done & ~truncated + if not isinstance(terminated, np.ndarray): + # if it's not a ndarray, we must return bool + # since it's not a bool, we make it so + terminated = bool(terminated) + return (observations, reward, terminated, truncated, done, info) + + @implement_for("gym", "0.24", "0.26") + def _output_transform(self, step_outputs_tuple): # noqa: F811 + observations, reward, done, info = step_outputs_tuple + # The variable naming follows torchrl's convention here. + # A done is interpreted the union of terminated and truncated. + # (as in earlier versions of gym). + truncated = info.pop("TimeLimit.truncated", False) + if not isinstance(done, bool) and isinstance(truncated, bool): + # if bool is an array, make truncated an array + truncated = [truncated] * len(done) + truncated = np.array(truncated) + elif not isinstance(truncated, bool): + # make sure it's a boolean np.array + truncated = np.array(truncated, dtype=np.dtype("bool")) + terminated = done & ~truncated + if not isinstance(terminated, np.ndarray): + # if it's not a ndarray, we must return bool + # since it's not a bool, we make it so + terminated = bool(terminated) + return (observations, reward, terminated, truncated, done, info) + + @implement_for("gym", "0.26", None) + def _output_transform(self, step_outputs_tuple): # noqa: F811 + # The variable naming follows torchrl's convention here. + observations, reward, terminated, truncated, info = step_outputs_tuple + return ( + observations, + reward, + terminated, + truncated, + terminated | truncated, + info, + ) + + @implement_for("gymnasium", "0.27", None) + def _output_transform(self, step_outputs_tuple): # noqa: F811 + # The variable naming follows torchrl's convention here. 
+ observations, reward, terminated, truncated, info = step_outputs_tuple + return ( + observations, + reward, + terminated, + truncated, + terminated | truncated, + info, + ) def _init_env(self): self.reset() @@ -557,14 +901,30 @@ def rebuild_with_kwargs(self, **new_kwargs): @property def info_dict_reader(self): - if self._info_dict_reader is None: - self._info_dict_reader = default_info_dict_reader() + if not self._info_dict_reader: + self._info_dict_reader.append(default_info_dict_reader()) return self._info_dict_reader @info_dict_reader.setter def info_dict_reader(self, value: callable): self._info_dict_reader = value + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + if self._is_batched: + # batched (aka 'vectorized') env reset is a bit special: envs are + # automatically reset. What we do here is just to check if _reset + # is present. If it is not, we just reset. Otherwise we just skip. + if tensordict is None: + return super()._reset(tensordict) + reset = tensordict.get("_reset", None) + if reset is None: + return super()._reset(tensordict) + elif reset is not None: + return tensordict.clone(False) + return super()._reset(tensordict, **kwargs) + ACCEPTED_TYPE_ERRORS = { "render_mode": "__init__() got an unexpected keyword argument 'render_mode'", @@ -610,6 +970,9 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) + def _async_env(self, *args, **kwargs): + return gym_backend("vector").AsyncVectorEnv(*args, **kwargs) + def _build_env( self, env_name: str, @@ -621,13 +984,10 @@ def _build_env( f"Consider downloading and installing gym from" f" {self.git_url}" ) - from_pixels = kwargs.get("from_pixels", False) + from_pixels = kwargs.pop("from_pixels", False) self._set_gym_default(kwargs, from_pixels) - if "from_pixels" in kwargs: - del kwargs["from_pixels"] - pixels_only = kwargs.get("pixels_only", True) - if "pixels_only" in kwargs: - del kwargs["pixels_only"] + pixels_only = kwargs.pop("pixels_only", True) + num_envs = kwargs.pop("num_envs", 0) made_env = False kwargs["frameskip"] = self.frame_skip self.wrapper_frame_skip = 1 @@ -654,7 +1014,18 @@ def _build_env( kwargs.pop("render_mode") else: raise err - return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + if num_envs > 0: + try: + env = self._async_env([CloudpickleWrapper(lambda: env)] * num_envs) + except RuntimeError: + # It would fail if the environment is not pickable. In that case, + # delegating environment instantiation to each subprocess as a fallback. + env = self._async_env( + [lambda: self.lib.make(env_name, **kwargs)] * num_envs + ) + self.batch_size = torch.Size([num_envs, *self.batch_size]) + return env @implement_for("gym", None, "0.25.1") def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 @@ -728,3 +1099,100 @@ def lib(self) -> ModuleType: raise ImportError("MO-gymnasium not found, check installation") from err _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + +class terminal_obs_reader(BaseInfoDictReader): + """Terminal observation reader for 'vectorized' gym environments. + + When running envs in parallel, Gym(nasium) writes the result of the true call + to `step` in `"final_observation"` entry within the `info` dictionary. + + This breaks the natural flow and makes single-processed and multiprocessed envs + incompatible. 
+ + This class reads the info obs, removes the `"final_observation"` from + the env and writes its content in the data. + + Next, a :class:`torchrl.envs.VecGymEnvTransform` transform will reorganise the + data by caching the result of the (implicit) reset and swap the true next + observation with the reset one. At reset time, the true reset data will be + replaced. + + Args: + observation_spec (CompositeSpec): The observation spec of the gym env. + backend (str, optional): the backend of the env. One of `"sb3"` for + stable-baselines3 or `"gym"` for gym/gymnasium. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + + """ + + backend_key = { + "sb3": "terminal_observation", + "gym": "final_observation", + } + + def __init__(self, observation_spec: CompositeSpec, backend, name="final"): + self.name = name + self._info_spec = CompositeSpec( + {(self.name, key): item.clone() for key, item in observation_spec.items()}, + shape=observation_spec.shape, + ) + self.backend = backend + + @property + def info_spec(self): + return self._info_spec + + def _read_obs(self, obs, key, tensor, index): + if obs is None: + return + if isinstance(obs, np.ndarray): + # Simplest case: there is one observation, + # presented as a np.ndarray. The key should be pixels or observation. + # We just write that value at its location in the tensor + tensor[index] = torch.as_tensor(obs, device=tensor.device) + elif isinstance(obs, dict): + if key not in obs: + raise KeyError( + f"The observation {key} could not be found in the final observation dict." + ) + subobs = obs[key] + if subobs is not None: + # if the obs is a dict, we expect that the key points also to + # a value in the obs. We retrieve this value and write it in the + # tensor + tensor[index] = torch.as_tensor(subobs, device=tensor.device) + + elif isinstance(obs, (list, tuple)): + # tuples are stacked along the first dimension when passing gym spaces + # to torchrl specs. As such, we can simply stack the tuple and set it + # at the relevant index (assuming stacking can be achieved) + tensor[index] = torch.as_tensor(obs, device=tensor.device) + else: + raise NotImplementedError( + f"Observations of type {type(obs)} are not supported yet." + ) + + def __call__(self, info_dict, tensordict): + terminal_obs = info_dict.get(self.backend_key[self.backend], None) + for key, item in self.info_spec.items(True, True): + final_obs = item.zero() + if terminal_obs is not None: + for i, obs in enumerate(terminal_obs): + self._read_obs(obs, key[-1], final_obs, index=i) + tensordict.set(key, final_obs) + return tensordict + + +def _flip_info_tuple(info: Tuple[Dict]) -> Dict[str, tuple]: + # In Gym < 0.24, batched envs returned tuples of dict, and not dict of tuples. + # We patch this by flipping the tuple -> dict order. + info_example = set(info[0]) + for item in info[1:]: + info_example = info_example.union(item) + result = {} + for key in info_example: + result[key] = tuple(_info.get(key, None) for _info in info) + return result diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 6074ca42207..52f12140a51 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -3,23 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
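The vectorized-Gym machinery added to gym.py above (the ``num_envs`` handling, ``_AsyncMeta``, ``VecGymEnvTransform`` and ``terminal_obs_reader``) can be exercised with a minimal usage sketch, assuming a recent gym or gymnasium is installed:

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1", num_envs=4)  # builds an AsyncVectorEnv and returns a TransformedEnv
assert env.batch_size == (4,)
rollout = env.rollout(10)  # ("next", "observation") carries the true final observation,
                           # not the post-auto-reset one
env.close()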
import functools +import importlib.util import torch -from torchrl.data import DEVICE_TYPING -from torchrl.envs import EnvBase +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs.common import EnvBase from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import classproperty +from torchrl.envs.utils import _classproperty -IMPORT_ERR = None -try: - import habitat - import habitat.gym # noqa - - _has_habitat = True -except ImportError as err: - _has_habitat = False - IMPORT_ERR = err +_has_habitat = importlib.util.find_spec("habitat") is not None def _wrap_import_error(fun): @@ -31,7 +24,7 @@ def new_fun(*args, **kwargs): "it or solving the import bugs (see attached error message). " "Refer to TorchRL's knowledge base in the documentation to " "debug habitat installation." - ) from IMPORT_ERR + ) return fun(*args, **kwargs) return new_fun @@ -55,14 +48,19 @@ class HabitatEnv(GymEnv): @_wrap_import_error @set_gym_backend("gym") def __init__(self, env_name, **kwargs): + import habitat # noqa + import habitat.gym # noqa + device_num = torch.device(kwargs.pop("device", 0)).index kwargs["override_options"] = [ f"habitat.simulator.habitat_sim_v0.gpu_device_id={device_num}", ] super().__init__(env_name=env_name, **kwargs) - @classproperty + @_classproperty def available_envs(cls): + if not _has_habitat: + return yield from _get_available_envs() def _build_gym_env(self, env, pixels_only): diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py new file mode 100644 index 00000000000..9206d23d09a --- /dev/null +++ b/torchrl/envs/libs/isaacgym.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util + +import itertools +import warnings +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch + +from tensordict import TensorDictBase +from torchrl.envs.libs.gym import GymWrapper +from torchrl.envs.utils import _classproperty, make_composite_from_td + +_has_isaac = importlib.util.find_spec("isaacgym") is not None + + +class IsaacGymWrapper(GymWrapper): + """Wrapper for IsaacGymEnvs environments. + + The original library can be found `here `_ + and is based on IsaacGym which can be downloaded `through NVIDIA's webpage _`. + + .. note:: IsaacGym environments cannot be executed consecutively, ie. instantiating one + environment after another (even if it has been cleared) will cause + CUDA memory issues. We recommend creating one environment per process only. + If you need more than one environment, the best way to achieve that is + to spawn them across processes. + + .. note:: IsaacGym works on CUDA devices by essence. Make sure your machine + has GPUs available and the required setup for IsaacGym (eg, Ubuntu 20.04). + + """ + + @property + def lib(self): + import isaacgym + + return isaacgym + + def __init__( + self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs # noqa: F821 + ): + warnings.warn( + "IsaacGym environment support is an experimental feature that may change in the future." 
+ ) + num_envs = env.num_envs + super().__init__( + env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs + ) + if not hasattr(self, "task"): + # by convention in IsaacGymEnvs + self.task = env.__name__ + + def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 + super()._make_specs(env, batch_size=self.batch_size) + self.full_done_spec = { + key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True) + } + self.observation_spec["obs"] = self.observation_spec["observation"] + del self.observation_spec["observation"] + + data = self.rollout(3).get("next")[..., 0] + del data[self.reward_key] + for done_key in self.done_keys: + try: + del data[done_key] + except KeyError: + continue + specs = make_composite_from_td(data) + + obs_spec = self.observation_spec + obs_spec.unlock_() + obs_spec.update(specs) + obs_spec.lock_() + self.__dict__["full_observation_spec"] = obs_spec + + @classmethod + def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs): + import isaacgym # noqa + import isaacgymenvs # noqa + + envs = isaacgymenvs.make( + seed=seed, + task=task, + num_envs=num_envs, + sim_device=str(device), + rl_device=str(device), + headless=headless, + **kwargs, + ) + return envs + + def _set_seed(self, seed: int) -> int: + # as of #665c32170d84b4be66722eea405a1e08b6e7f761 the seed points nowhere in gym.make for IsaacGymEnvs + return seed + + def read_action(self, action): + """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment. + + Args: + action (Tensor or TensorDict): an action to be taken in the environment + + Returns: an action in a format compatible with the contained environment. + + """ + return action + + def read_done( + self, + terminated: bool = None, + truncated: bool | None = None, + done: bool | None = None, + ) -> Tuple[bool, bool, bool]: + if terminated is not None: + terminated = terminated.bool() + if truncated is not None: + truncated = truncated.bool() + if done is not None: + done = done.bool() + return terminated, truncated, done, done.any() + + def read_reward(self, total_reward, step_reward): + """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two. + + Args: + total_reward (torch.Tensor or TensorDict): total reward so far in the step + step_reward (reward in the format provided by the inner env): reward of this particular step + + """ + return total_reward + step_reward + + def read_obs( + self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] + ) -> Dict[str, Any]: + """Reads an observation from the environment and returns an observation compatible with the output TensorDict. + + Args: + observations (observation under a format dictated by the inner env): observation to be read. + + """ + if isinstance(observations, dict): + if "state" in observations and "observation" not in observations: + # we rename "state" in "observation" as "observation" is the conventional name + # for single observation in torchrl. + # naming it 'state' will result in envs that have a different name for the state vector + # when queried with and without pixels + observations["observation"] = observations.pop("state") + if not isinstance(observations, (TensorDictBase, dict)): + (key,) = itertools.islice(self.observation_spec.keys(True, True), 1) + observations = {key: observations} + return observations + + +class IsaacGymEnv(IsaacGymWrapper): + """A TorchRL Env interface for IsaacGym environments. 
+ + See :class:`~.IsaacGymWrapper` for more information. + + Examples: + >>> env = IsaacGymEnv(task="Ant", num_envs=2000, device="cuda:0") + >>> rollout = env.rollout(3) + >>> assert env.batch_size == (2000,) + + """ + + @_classproperty + def available_envs(cls): + if not _has_isaac: + return + + import isaacgymenvs # noqa + + yield from isaacgymenvs.tasks.isaacgym_task_map.keys() + + def __init__(self, task=None, *, env=None, num_envs, device, **kwargs): + if env is not None and task is not None: + raise RuntimeError("Cannot provide both `task` and `env` arguments.") + elif env is not None: + task = env + envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs) + self.task = task + super().__init__(envs, **kwargs) diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 5eac0a42c9e..10889a169cd 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -2,14 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import dataclasses +import importlib.util from typing import Union -import jax +# import jax import numpy as np import torch -from jax import dlpack as jax_dlpack, numpy as jnp + +# from jax import dlpack as jax_dlpack, numpy as jnp from tensordict.tensordict import make_tensordict, TensorDictBase from torch.utils import dlpack as torch_dlpack from torchrl.data.tensor_specs import ( @@ -20,13 +21,19 @@ ) from torchrl.data.utils import numpy_to_torch_dtype_dict +_has_jax = importlib.util.find_spec("jax") is not None + def _tree_reshape(x, batch_size: torch.Size): + import jax + shape, n = batch_size, 1 return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) def _tree_flatten(x, batch_size: torch.Size): + import jax + shape, n = (batch_size.numel(),), len(batch_size) return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) @@ -38,7 +45,11 @@ def _tree_flatten(x, batch_size: torch.Size): } -def _ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: +def _ndarray_to_tensor( + value: Union["jnp.ndarray", np.ndarray] # noqa: F821 +) -> torch.Tensor: + from jax import dlpack as jax_dlpack, numpy as jnp + # JAX arrays generated by jax.vmap would have Numpy dtypes. 
if value.dtype in _dtype_conversion: value = value.view(_dtype_conversion[value.dtype]) @@ -53,7 +64,9 @@ def _ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: return out.to(numpy_to_torch_dtype_dict[value.dtype]) -def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: +def _tensor_to_ndarray(value: torch.Tensor) -> "jnp.ndarray": # noqa: F821 + from jax import dlpack as jax_dlpack + return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) @@ -75,6 +88,8 @@ def _get_object_fields(obj) -> dict: def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase: """Converts a namedtuple or a dataclass to a TensorDict.""" + from jax import numpy as jnp + t = {} _fields = _get_object_fields(obj) for name, value in _fields.items(): @@ -94,6 +109,8 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase: def _tensordict_to_object(tensordict: TensorDictBase, object_example): """Converts a TensorDict to a namedtuple or a dataclass.""" + from jax import dlpack as jax_dlpack + t = {} _fields = _get_object_fields(object_example) for name, example in _fields.items(): diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 690b81f2c47..8cffd1ec97f 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -2,12 +2,16 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib.util -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import torch from tensordict.tensordict import TensorDict, TensorDictBase +from torchrl.envs.utils import _classproperty + +_has_jumanji = importlib.util.find_spec("jumanji") is not None from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -20,31 +24,23 @@ UnboundedDiscreteTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict -from torchrl.envs import GymLikeEnv - -try: - import jax - import jumanji - from jax import numpy as jnp - from torchrl.envs.libs.jax_utils import ( - _extract_spec, - _ndarray_to_tensor, - _object_to_tensordict, - _tensordict_to_object, - _tree_flatten, - _tree_reshape, - ) - - _has_jumanji = True - IMPORT_ERR = "" -except ImportError as err: - _has_jumanji = False - IMPORT_ERR = str(err) +from torchrl.envs.gym_like import GymLikeEnv + +from torchrl.envs.libs.jax_utils import ( + _extract_spec, + _ndarray_to_tensor, + _object_to_tensordict, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, +) def _get_envs(): if not _has_jumanji: - return [] + raise ImportError("Jumanji is not installed in your virtual environment.") + import jumanji + return jumanji.registered_environments() @@ -54,6 +50,8 @@ def _jumanji_to_torchrl_spec_transform( device: DEVICE_TYPING = None, categorical_action_encoding: bool = True, ) -> TensorSpec: + import jumanji + if isinstance(spec, jumanji.specs.DiscreteArray): action_space_cls = ( DiscreteTensorSpec @@ -69,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform( dtype = numpy_to_torch_dtype_dict[spec.dtype] return BoundedTensorSpec( shape=shape, - minimum=np.asarray(spec.minimum), - maximum=np.asarray(spec.maximum), + low=np.asarray(spec.minimum), + high=np.asarray(spec.maximum), dtype=dtype, device=device, ) @@ -132,18 +130,25 @@ class JumanjiWrapper(GymLikeEnv): """ git_url = "https://github.com/instadeepai/jumanji" - available_envs = _get_envs() libname = "jumanji" + @_classproperty + def available_envs(cls): + if not _has_jumanji: + return + yield from 
_get_envs() + @property def lib(self): + import jumanji + return jumanji - def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): + def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): # noqa: F821 if not _has_jumanji: raise ImportError( "jumanji is not installed or importing it failed. Consider checking your installation." - ) from IMPORT_ERR + ) if env is not None: kwargs["env"] = env super().__init__(**kwargs) @@ -166,6 +171,9 @@ def _build_env( return env def _make_state_example(self, env): + import jax + from jax import numpy as jnp + key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state, _ = jax.vmap(env.reset)(jnp.stack(keys)) @@ -173,6 +181,8 @@ def _make_state_example(self, env): return state def _make_state_spec(self, env) -> TensorSpec: + import jax + key = jax.random.PRNGKey(0) state, _ = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) @@ -187,6 +197,8 @@ def _make_action_spec(self, env) -> TensorSpec: return action_spec def _make_observation_spec(self, env) -> TensorSpec: + jumanji = self.lib + spec = env.observation_spec() new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) if isinstance(spec, jumanji.specs.Array): @@ -222,6 +234,7 @@ def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 self._state_example = self._make_state_example(env) def _check_kwargs(self, kwargs: Dict): + jumanji = self.lib if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -231,7 +244,22 @@ def _check_kwargs(self, kwargs: Dict): def _init_env(self): pass + @property + def key(self): + key = getattr(self, "_key", None) + if key is None: + raise RuntimeError( + "the env.key attribute wasn't found. Make sure to call `env.set_seed(seed)` before any interaction." 
+ ) + return key + + @key.setter + def key(self, value): + self._key = value + def _set_seed(self, seed): + import jax + if seed is None: raise Exception("Jumanji requires an integer seed.") self.key = jax.random.PRNGKey(seed) @@ -241,6 +269,8 @@ def read_state(self, state): return self.state_spec["state"].encode(state_dict) def read_obs(self, obs): + from jax import numpy as jnp + if isinstance(obs, (list, jnp.ndarray, np.ndarray)): obs_dict = _ndarray_to_tensor(obs).to(self.device) else: @@ -248,11 +278,11 @@ def read_obs(self, obs): return super().read_obs(obs_dict) def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + import jax # prepare inputs state = _tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) - reward = self.reward_spec.zero() # flatten batch size into vector state = _tree_flatten(state, self.batch_size) @@ -268,7 +298,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # collect outputs state_dict = self.read_state(state) obs_dict = self.read_obs(timestep.observation) - reward = self.read_reward(reward, np.asarray(timestep.reward)) + reward = self.read_reward(np.asarray(timestep.reward)) done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) @@ -280,13 +310,17 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) tensordict_out.set("reward", reward) tensordict_out.set("done", done) + tensordict_out.set("terminated", done) + # tensordict_out.set("terminated", done) tensordict_out["state"] = state_dict - return tensordict_out.select().set("next", tensordict_out) + return tensordict_out def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: + import jax + from jax import numpy as jnp # generate random keys self.key, *keys = jax.random.split(self.key, self.numel() + 1) @@ -301,7 +335,7 @@ def _reset( # collect outputs state_dict = self.read_state(state) obs_dict = self.read_obs(timestep.observation) - done = self.done_spec.zero() + done_td = self.full_done_spec.zero() # build results tensordict_out = TensorDict( @@ -309,11 +343,17 @@ def _reset( batch_size=self.batch_size, device=self.device, ) - tensordict_out.set("done", done) + tensordict_out.update(done_td) tensordict_out["state"] = state_dict return tensordict_out + def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: + ... + + def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + ... + class JumanjiEnv(JumanjiWrapper): """Jumanji environment wrapper. @@ -333,13 +373,13 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "jumanji.env.Environment": + ) -> "jumanji.env.Environment": # noqa: F821 if not _has_jumanji: - raise RuntimeError( + raise ImportError( f"jumanji not found, unable to create {env_name}. " f"Consider installing jumanji. More info:" f" {self.git_url}." 
- ) from IMPORT_ERR + ) from_pixels = kwargs.pop("from_pixels", False) pixels_only = kwargs.pop("pixels_only", True) if kwargs: diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 8cbc9dfb5b4..25b5dfbad30 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -14,7 +14,8 @@ ) from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.envs import Compose, DoubleToFloat, EnvBase, RenameTransform +from torchrl.envs.common import EnvBase +from torchrl.envs.transforms import Compose, DoubleToFloat, RenameTransform def _make_composite_from_td(td): @@ -127,7 +128,7 @@ def _step( self.batch_size, device=self.device, ) - return td.select().set("next", td) + return td def _set_seed(self, seed): self.rng = torch.random.manual_seed(seed) diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py new file mode 100644 index 00000000000..1f8b02fd1f6 --- /dev/null +++ b/torchrl/envs/libs/pettingzoo.py @@ -0,0 +1,895 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import importlib +from typing import Dict, List, Optional, Tuple, Union + +import torch +from tensordict.tensordict import TensorDictBase + +from torchrl.data import ( + CompositeSpec, + DiscreteTensorSpec, + OneHotDiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend +from torchrl.envs.utils import ( + _classproperty, + _replace_last, + check_marl_grouping, + MarlGroupMapType, +) + +_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None + + +def _get_envs(): + if not _has_pettingzoo: + raise ImportError("PettingZoo is not installed in your virtual environment.") + from pettingzoo.utils.all_modules import all_environments + + return list(all_environments.keys()) + + +class PettingZooWrapper(_EnvWrapper): + """PettingZoo environment wrapper. + + To install petting zoo follow the guide `here __`. + + This class is a general torchrl wrapper for all PettingZoo environments. + It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``. + + Let's see how more in details: + + In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step. + If the number of agents during the task varies, please set ``use_mask=True``. + ``"mask"`` will be provided + as an output in each group and should be used to mask out dead agents. + The environment will be reset as soon as one agent is done. + + In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act. + For this reason, it is compulsory to set ``use_mask=True`` for this type of environment. + ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents. + The environment will be reset only when all agents are done. + + If there are any unavailable actions for an agent, + the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"`` + for each group to reflect the latest available actions. This should be passed to a masked distribution during + training. + + As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment. 
+    You can group agents together (stacking their tensors) to leverage vectorization when passing them through the
+    same neural network. You can split agents in different groups where they are heterogeneous or should be
+    processed by different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
+
+    By default, agents in pettingzoo will be grouped by name.
+    For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
+
+    >>> print(env.rand_action(env.reset()))
+    TensorDict(
+        fields={
+            agent: TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
+                    action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3]))},
+            adversary: TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
+                    action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
+                    done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([1]))},
+        batch_size=torch.Size([]))
+    >>> print(env.group_map)
+    {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
+
+    Otherwise, a group map can be specified or selected from some premade options.
+    See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
+    For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, indicating that each agent should
+    have its own tensordict (similar to the pettingzoo parallel API).
+
+    Grouping is useful for leveraging vectorization among agents whose data goes through the same
+    neural network.
+
+    Args:
+        env (``pettingzoo.utils.env.ParallelEnv`` or ``pettingzoo.utils.env.AECEnv``): the pettingzoo environment to wrap.
+        return_state (bool, optional): whether to return the global state from pettingzoo
+            (not available in all environments). Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
+            input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
+            or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
+        use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
+            wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should also be used
+            for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
+        categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
+            them to categorical or one-hot. Defaults to ``True``.
+        seed (int, optional): the seed. Defaults to ``None``.
+ + Examples: + >>> # Parallel env + >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper + >>> from pettingzoo.butterfly import pistonball_v6 + >>> kwargs = {"n_pistons": 21, "continuous": True} + >>> env = PettingZooWrapper( + ... env=pistonball_v6.parallel_env(**kwargs), + ... return_state=True, + ... group_map=None, # Use default for parallel (all pistons grouped together) + ... ) + >>> print(env.group_map) + ... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']} + >>> env.rollout(10) + >>> # AEC env + >>> from pettingzoo.classic import tictactoe_v3 + >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper + >>> from torchrl.envs.utils import MarlGroupMapType + >>> env = PettingZooWrapper( + ... env=tictactoe_v3.env(), + ... use_mask=True, # Must use it since one player plays at a time + ... group_map=None # # Use default for AEC (one group per player) + ... ) + >>> print(env.group_map) + ... {'player_1': ['player_1'], 'player_2': ['player_2']} + >>> env.rollout(10) + """ + + git_url = "https://github.com/Farama-Foundation/PettingZoo" + libname = "pettingzoo" + + @_classproperty + def available_envs(cls): + if not _has_pettingzoo: + return + yield from _get_envs() + + def __init__( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ] = None, + return_state: Optional[bool] = False, + group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None, + use_mask: bool = False, + categorical_actions: bool = True, + seed: Optional[int] = None, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + self.group_map = group_map + self.return_state = return_state + self.seed = seed + self.use_mask = use_mask + self.categorical_actions = categorical_actions + + super().__init__(**kwargs, allow_done_after_reset=True) + + def _get_default_group_map(self, agent_names: List[str]): + # This function performs the default grouping in pettingzoo + if not self.parallel: + # In AEC envs we will have one group per agent by default + group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names) + else: + # In parallel envs, by default + # Agents with names "str_int" will be grouped in group name "str" + group_map = {} + for agent_name in agent_names: + # See if the agent follows the convention "name_int" + follows_convention = True + agent_name_split = agent_name.split("_") + if len(agent_name_split) == 1: + follows_convention = False + try: + int(agent_name_split[-1]) + except ValueError: + follows_convention = False + + # If not, just put it in a single group + if not follows_convention: + group_map[agent_name] = [agent_name] + # Otherwise, group it with other agents that follow the same convention + else: + group_name = "_".join(agent_name_split[:-1]) + if group_name in group_map: + group_map[group_name].append(agent_name) + else: + group_map[group_name] = [agent_name] + + return group_map + + @property + def lib(self): + import pettingzoo + + return pettingzoo + + def _build_env( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ], + ): + import pettingzoo + + self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv) + if not self.parallel and not self.use_mask: + raise ValueError("For AEC environments you need to set use_mask=True") + if len(self.batch_size): + raise RuntimeError( + f"PettingZoo does not support custom batch_size {self.batch_size}." 
+ ) + + return env + + @set_gym_backend("gymnasium") + def _make_specs( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ], + ) -> None: + + # Create and check group map + if self.group_map is None: + self.group_map = self._get_default_group_map(self.possible_agents) + elif isinstance(self.group_map, MarlGroupMapType): + self.group_map = self.group_map.get_group_map(self.possible_agents) + check_marl_grouping(self.group_map, self.possible_agents) + self.has_action_mask = {group: False for group in self.group_map.keys()} + + action_spec = CompositeSpec() + observation_spec = CompositeSpec() + reward_spec = CompositeSpec() + done_spec = CompositeSpec() + for group, agents in self.group_map.items(): + ( + group_observation_spec, + group_action_spec, + group_reward_spec, + group_done_spec, + ) = self._make_group_specs(group_name=group, agent_names=agents) + action_spec[group] = group_action_spec + observation_spec[group] = group_observation_spec + reward_spec[group] = group_reward_spec + done_spec[group] = group_done_spec + + self.action_spec = action_spec + self.observation_spec = observation_spec + self.reward_spec = reward_spec + self.done_spec = done_spec + + def _make_group_specs(self, group_name: str, agent_names: List[str]): + n_agents = len(agent_names) + action_specs = [] + observation_specs = [] + for agent in agent_names: + action_specs.append( + CompositeSpec( + { + "action": _gym_to_torchrl_spec_transform( + self.action_space(agent), + remap_state_to_observation=False, + categorical_action_encoding=self.categorical_actions, + device=self.device, + ) + }, + ) + ) + observation_specs.append( + CompositeSpec( + { + "observation": _gym_to_torchrl_spec_transform( + self.observation_space(agent), + remap_state_to_observation=False, + device=self.device, + ) + } + ) + ) + group_action_spec = torch.stack(action_specs, dim=0) + group_observation_spec = torch.stack(observation_specs, dim=0) + + # Sometimes the observation spec contains an action mask. + # Or sometimes the info spec contains an action mask. + # We uniform this by removing it from both places and optionally set it in a standard location. 
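# For instance, PettingZoo classic games such as tictactoe_v3 usually expose the mask under
# observation["action_mask"], while some other tasks only report it through the info dict
# (this is an assumption about the wrapped env, not something enforced by pettingzoo itself).
# In both cases the mask is re-exposed below as a group-level "action_mask" entry whose shape
# matches the one-hot action space.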
+ group_observation_inner_spec = group_observation_spec["observation"] + if ( + isinstance(group_observation_inner_spec, CompositeSpec) + and "action_mask" in group_observation_inner_spec.keys() + ): + self.has_action_mask[group_name] = True + del group_observation_inner_spec["action_mask"] + group_observation_spec["action_mask"] = DiscreteTensorSpec( + n=2, + shape=group_action_spec["action"].shape + if not self.categorical_actions + else ( + *group_action_spec["action"].shape, + group_action_spec["action"].space.n, + ), + dtype=torch.bool, + device=self.device, + ) + + if self.use_mask: + group_observation_spec["mask"] = DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents,)), + dtype=torch.bool, + device=self.device, + ) + + group_reward_spec = CompositeSpec( + { + "reward": UnboundedContinuousTensorSpec( + shape=torch.Size((n_agents, 1)), + device=self.device, + dtype=torch.float32, + ) + }, + shape=torch.Size((n_agents,)), + ) + group_done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents, 1)), + dtype=torch.bool, + device=self.device, + ), + "terminated": DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents, 1)), + dtype=torch.bool, + device=self.device, + ), + "truncated": DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents, 1)), + dtype=torch.bool, + device=self.device, + ), + }, + shape=torch.Size((n_agents,)), + ) + return ( + group_observation_spec, + group_action_spec, + group_reward_spec, + group_done_spec, + ) + + def _check_kwargs(self, kwargs: Dict): + import pettingzoo + + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance( + env, (pettingzoo.utils.env.ParallelEnv, pettingzoo.utils.env.AECEnv) + ): + raise TypeError("env is not of type expected.") + + def _init_env(self) -> Optional[int]: + # Add info + if self.parallel: + _, info_dict = self._reset_parallel(seed=self.seed) + else: + _, info_dict = self._reset_aec(seed=self.seed) + + for group, agents in self.group_map.items(): + info_specs = [] + for agent in agents: + info_specs.append( + CompositeSpec( + { + "info": CompositeSpec( + { + key: UnboundedContinuousTensorSpec( + shape=torch.tensor(value).shape, + device=self.device, + ) + for key, value in info_dict[agent].items() + } + ) + }, + device=self.device, + ) + ) + info_specs = torch.stack(info_specs, dim=0) + if ("info", "action_mask") in info_specs.keys(True, True): + if not self.has_action_mask[group]: + self.has_action_mask[group] = True + group_action_spec = self.input_spec[ + "full_action_spec", group, "action" + ] + self.observation_spec[group]["action_mask"] = DiscreteTensorSpec( + n=2, + shape=group_action_spec.shape + if not self.categorical_actions + else (*group_action_spec.shape, group_action_spec.space.n), + dtype=torch.bool, + device=self.device, + ) + group_inner_info_spec = info_specs["info"] + del group_inner_info_spec["action_mask"] + + if len(info_specs["info"].keys()): + self.observation_spec[group].update(info_specs) + + if self.return_state: + try: + state_spec = _gym_to_torchrl_spec_transform( + self.state_space, + remap_state_to_observation=False, + device=self.device, + ) + except AttributeError: + state_example = torch.tensor(self.state(), device=self.device) + state_spec = UnboundedContinuousTensorSpec( + shape=state_example.shape, + dtype=state_example.dtype, + device=self.device, + ) + self.observation_spec["state"] = state_spec + + # Caching + self.cached_reset_output_zero = 
self.observation_spec.zero() + self.cached_reset_output_zero.update(self.output_spec["full_done_spec"].zero()) + + self.cached_step_output_zero = self.observation_spec.zero() + self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero()) + self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero()) + + def _set_seed(self, seed: int): + self.seed = seed + self.reset(seed=self.seed) + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + if self.parallel: + # This resets when any is done + observation_dict, info_dict = self._reset_parallel(**kwargs) + else: + # This resets when all are done + observation_dict, info_dict = self._reset_aec(tensordict, **kwargs) + + # We start with zeroed data and fill in the data for alive agents + tensordict_out = self.cached_reset_output_zero.clone() + # Update the "mask" for non-acting agents + self._update_agent_mask(tensordict_out) + # Update the "action_mask" for non-available actions + observation_dict, info_dict = self._update_action_mask( + tensordict_out, observation_dict, info_dict + ) + + # Now we get the data (obs and info) + for group, agent_names in self.group_map.items(): + group_observation = tensordict_out.get((group, "observation")) + group_info = tensordict_out.get((group, "info"), None) + + for index, agent in enumerate(agent_names): + group_observation[index] = self.observation_spec[group, "observation"][ + index + ].encode(observation_dict[agent]) + if group_info is not None: + agent_info_dict = info_dict[agent] + for agent_info, value in agent_info_dict.items(): + group_info.get(agent_info)[index] = torch.tensor( + value, device=self.device + ) + + return tensordict_out + + def _reset_aec(self, tensordict=None, **kwargs) -> Tuple[Dict, Dict]: + all_done = True + if tensordict is not None: + _resets = [] + for done_key in self.done_keys: + _reset_key = _replace_last(done_key, "_reset") + _reset = tensordict.get(_reset_key, default=None) + if _reset is None: + continue + _resets.append(_reset) + if len(_resets) < len(self.done_keys): + all_done = False + else: + for _reset in _resets: + if not _reset.all(): + all_done = False + break + + if all_done: + self._env.reset(**kwargs) + + observation_dict = { + agent: self._env.observe(agent) for agent in self.possible_agents + } + info_dict = self._env.infos + return observation_dict, info_dict + + def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]: + return self._env.reset(**kwargs) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + + if self.parallel: + ( + observation_dict, + rewards_dict, + terminations_dict, + truncations_dict, + info_dict, + ) = self._step_parallel(tensordict) + else: + ( + observation_dict, + rewards_dict, + terminations_dict, + truncations_dict, + info_dict, + ) = self._step_aec(tensordict) + + # We start with zeroed data and fill in the data for alive agents + tensordict_out = self.cached_step_output_zero.clone() + # Update the "mask" for non-acting agents + self._update_agent_mask(tensordict_out) + # Update the "action_mask" for non-available actions + observation_dict, info_dict = self._update_action_mask( + tensordict_out, observation_dict, info_dict + ) + + # Now we get the data + for group, agent_names in self.group_map.items(): + group_observation = tensordict_out.get((group, "observation")) + group_reward = tensordict_out.get((group, "reward")) + group_done = tensordict_out.get((group, "done")) + group_terminated = tensordict_out.get((group, 
"terminated")) + group_truncated = tensordict_out.get((group, "truncated")) + group_info = tensordict_out.get((group, "info"), None) + + for index, agent in enumerate(agent_names): + if agent in observation_dict: # Live agents + group_observation[index] = self.observation_spec[ + group, "observation" + ][index].encode(observation_dict[agent]) + group_reward[index] = torch.tensor( + rewards_dict[agent], + device=self.device, + dtype=torch.float32, + ) + group_done[index] = torch.tensor( + terminations_dict[agent] or truncations_dict[agent], + device=self.device, + dtype=torch.bool, + ) + group_truncated[index] = torch.tensor( + truncations_dict[agent], + device=self.device, + dtype=torch.bool, + ) + group_terminated[index] = torch.tensor( + terminations_dict[agent], + device=self.device, + dtype=torch.bool, + ) + + if group_info is not None: + agent_info_dict = info_dict[agent] + for agent_info, value in agent_info_dict.items(): + group_info.get(agent_info)[index] = torch.tensor( + value, device=self.device + ) + + elif not self.use_action_mask: + # Dead agent, if we are not masking it out, this is not allowed + raise ValueError( + "Dead agents found in the environment," + " you need to set use_action_mask=True to allow this." + ) + + return tensordict_out + + def _step_parallel( + self, + tensordict: TensorDictBase, + ) -> Tuple[Dict, Dict, Dict, Dict, Dict]: + action_dict = {} + for group, agents in self.group_map.items(): + group_action = tensordict.get((group, "action")) + group_action_np = self.input_spec[ + "full_action_spec", group, "action" + ].to_numpy(group_action) + for index, agent in enumerate(agents): + action_dict[agent] = group_action_np[index] + + return self._env.step(action_dict) + + def _step_aec( + self, + tensordict: TensorDictBase, + ) -> Tuple[Dict, Dict, Dict, Dict, Dict]: + + for group, agents in self.group_map.items(): + if self.agent_selection in agents: + agent_index = agents.index(self._env.agent_selection) + group_action = tensordict.get((group, "action")) + group_action_np = self.input_spec[ + "full_action_spec", group, "action" + ].to_numpy(group_action) + action = group_action_np[agent_index] + break + + self._env.step(action) + terminations_dict = self._env.terminations + truncations_dict = self._env.truncations + info_dict = self._env.infos + rewards_dict = self._env.rewards + observation_dict = { + agent: self._env.observe(agent) for agent in self.possible_agents + } + return ( + observation_dict, + rewards_dict, + terminations_dict, + truncations_dict, + info_dict, + ) + + def _update_action_mask(self, td, observation_dict, info_dict): + + # Since we remove the action_mask keys we need to copy the data + observation_dict = copy.deepcopy(observation_dict) + info_dict = copy.deepcopy(info_dict) + # In AEC only one agent acts, in parallel env self.agents contains the agents alive + agents_acting = self.agents if self.parallel else [self.agent_selection] + + for group, agents in self.group_map.items(): + if self.has_action_mask[group]: + group_mask = td.get((group, "action_mask")) + group_mask += True + for index, agent in enumerate(agents): + agent_obs = observation_dict[agent] + agent_info = info_dict[agent] + if isinstance(agent_obs, Dict) and "action_mask" in agent_obs: + if agent in agents_acting: + group_mask[index] = torch.tensor( + agent_obs["action_mask"], + device=self.device, + dtype=torch.bool, + ) + del agent_obs["action_mask"] + elif isinstance(agent_info, Dict) and "action_mask" in agent_info: + if agent in agents_acting: + group_mask[index] 
= torch.tensor( + agent_info["action_mask"], + device=self.device, + dtype=torch.bool, + ) + del agent_info["action_mask"] + + group_action_spec = self.input_spec["full_action_spec", group, "action"] + if isinstance( + group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec) + ): + # We update the mask for available actions + group_action_spec.update_mask(group_mask.clone()) + + return observation_dict, info_dict + + def _update_agent_mask(self, td): + if self.use_mask: + # In AEC only one agent acts, in parallel env self.agents contains the agents alive + agents_acting = self.agents if self.parallel else [self.agent_selection] + for group, agents in self.group_map.items(): + group_mask = td.get((group, "mask")) + group_mask += True + + # We now add dead agents to the mask + for index, agent in enumerate(agents): + if agent not in agents_acting: + group_mask[index] = False + + def close(self) -> None: + self._env.close() + + +class PettingZooEnv(PettingZooWrapper): + """PettingZoo Environment. + + To install PettingZoo, follow the guide `here __`. + + This class is a general torchrl wrapper for all PettingZoo environments. + It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``. + + Let's see this in more detail: + + For wrapping ``pettingzoo.ParallelEnv``, provide the name of your pettingzoo task (in the ``task`` argument) + and specify ``parallel=True``. This will construct the ``pettingzoo.ParallelEnv`` version of that task + (if it is supported in pettingzoo) and wrap it for torchrl. + In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step. + If the number of agents during the task varies, please set ``use_mask=True``. + ``"mask"`` will be provided + as an output in each group and should be used to mask out dead agents. + The environment will be reset as soon as one agent is done. + + For wrapping ``pettingzoo.AECEnv``, provide the name of your pettingzoo task (in the ``task`` argument) + and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task + and wrap it for torchrl. + In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act. + For this reason, it is compulsory to set ``use_mask=True`` for this type of environment. + ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents. + The environment will be reset only when all agents are done. + + If there are any unavailable actions for an agent, + the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"`` + for each group to reflect the latest available actions. This should be passed to a masked distribution during + training. + + As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment. + You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same + neural network. You can split agents in different groups where they are heterogeneous or should be processed by + different neural networks. To group, you just need to pass a ``group_map`` at env construction time. + + By default, agents in pettingzoo will be grouped by name. 
+ For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like: + + >>> print(env.rand_action(env.reset())) + TensorDict( + fields={ + agent: TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]))}, + adversary: TensorDict( + fields={ + action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False), + terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([1]))}, + batch_size=torch.Size([])) + >>> print(env.group_map) + {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]} + + Otherwise, a group map can be specified or selected from some premade options. + See :class:`torchrl.envs.utils.MarlGroupMapType` for more info. + For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, indicating that each agent should + have its own tensordict (similar to the pettingzoo parallel API). + + Grouping is useful for leveraging vectorisation among agents whose data goes through the same + neural network. + + Args: + task (str): the name of the pettingzoo task to create (for example, "multiwalker_v9"). + parallel (bool): whether to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``. + return_state (bool, optional): whether to return the global state from pettingzoo + (not available in all environments). Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for + input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified + or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info. + use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in + wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should also be used + for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``. + categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform + them to categorical or one-hot. + seed (int, optional): the seed. Defaults to ``None``. + + Examples: + >>> # Parallel env + >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv + >>> kwargs = {"n_pistons": 21, "continuous": True} + >>> env = PettingZooEnv( + ... task="pistonball_v6", + ... parallel=True, + ... return_state=True, + ... group_map=None, # Use default (all pistons grouped together) + ... **kwargs, + ... ) + >>> print(env.group_map) + ... 
{'piston': ['piston_0', 'piston_1', ..., 'piston_20']} + >>> env.rollout(10) + >>> # AEC env + >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv + >>> from torchrl.envs.utils import MarlGroupMapType + >>> env = PettingZooEnv( + ... task="tictactoe_v3", + ... parallel=False, + ... use_mask=True, # Must use it since one player plays at a time + ... group_map=None # # Use default for AEC (one group per player) + ... ) + >>> print(env.group_map) + ... {'player_1': ['player_1'], 'player_2': ['player_2']} + >>> env.rollout(10) + """ + + def __init__( + self, + task: str, + parallel: bool, + return_state: Optional[bool] = False, + group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None, + use_mask: bool = False, + categorical_actions: bool = True, + seed: Optional[int] = None, + **kwargs, + ): + if not _has_pettingzoo: + raise ImportError( + f"pettingzoo python package was not found. Please install this dependency. " + f"More info: {self.git_url}." + ) + kwargs["task"] = task + kwargs["parallel"] = parallel + kwargs["return_state"] = return_state + kwargs["group_map"] = group_map + kwargs["use_mask"] = use_mask + kwargs["categorical_actions"] = categorical_actions + kwargs["seed"] = seed + + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + if "task" not in kwargs: + raise TypeError("Could not find environment key 'task' in kwargs.") + if "parallel" not in kwargs: + raise TypeError("Could not find environment key 'parallel' in kwargs.") + + def _build_env( + self, + task: str, + parallel: bool, + **kwargs, + ) -> Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ]: + self.task_name = task + + from pettingzoo.utils.all_modules import all_environments + + if task not in all_environments: + # Try looking at the literal translation of values + task_module = None + for value in all_environments.values(): + if value.__name__.split(".")[-1] == task: + task_module = value + break + if task_module is None: + raise RuntimeError(f"Specified task not in {_get_envs()}") + else: + task_module = all_environments[task] + + if parallel: + petting_zoo_env = task_module.parallel_env(**kwargs) + else: + petting_zoo_env = task_module.env(**kwargs) + + return super()._build_env(env=petting_zoo_env) diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py new file mode 100644 index 00000000000..a3ee1dfa893 --- /dev/null +++ b/torchrl/envs/libs/robohive.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import importlib +import os +import warnings + +from copy import copy +from pathlib import Path + +import numpy as np +import torch +from tensordict import TensorDict +from tensordict.tensordict import make_tensordict +from torchrl._utils import implement_for +from torchrl.data import UnboundedContinuousTensorSpec +from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv +from torchrl.envs.utils import _classproperty, make_composite_from_td + +_has_gym = importlib.util.find_spec("gym") is not None +_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym + +if _has_robohive: + os.environ.setdefault("sim_backend", "MUJOCO") + + +class set_directory(object): + """Sets the cwd within the context. 
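Besides acting as a context manager, the object can also decorate a function so that it runs inside the given directory. A minimal sketch (the path is arbitrary):

>>> with set_directory(Path("/tmp")):
...     print(Path().absolute())  # prints the chosen directory
>>> @set_directory(Path("/tmp"))
... def build_assets():
...     ...  # any code that must run from that directory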
+ + Args: + path (Path): The path to the cwd + """ + + def __init__(self, path: Path): + self.path = path + self.origin = Path().absolute() + + def __enter__(self): + os.chdir(self.path) + + def __exit__(self, *args, **kwargs): + os.chdir(self.origin) + + def __call__(self, fun): + def new_fun(*args, **kwargs): + with set_directory(Path(self.path)): + return fun(*args, **kwargs) + + return new_fun + + +class _RoboHiveBuild(_AsyncMeta): + def __call__(self, *args, **kwargs): + instance: RoboHiveEnv = super().__call__(*args, **kwargs) + instance._refine_specs() + return instance + + +class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild): + """A wrapper for RoboHive gym environments. + + RoboHive is a collection of environments/tasks simulated with the MuJoCo physics engine exposed using the OpenAI-Gym API. + + Github: https://github.com/vikashplus/robohive/ + + RoboHive requires gym 0.13. + + Args: + env_name (str): the environment name to build. + read_info (bool, optional): whether the the info should be parsed. + Defaults to ``True``. + device (torch.device, optional): the device on which the input/output + are expected. Defaults to torch default device. + """ + + env_list = [] + + @_classproperty + def CURR_DIR(cls): + if _has_robohive: + import robohive.envs.multi_task.substeps1 + + return robohive.envs.multi_task.substeps1.CURR_DIR + else: + return None + + @_classproperty + def available_envs(cls): + if not _has_robohive: + return + RoboHiveEnv.register_envs() + yield from cls.env_list + + @classmethod + def register_envs(cls): + if not _has_robohive: + raise ImportError( + "Cannot load robohive from the current virtual environment." + ) + from robohive import robohive_env_suite as robohive_envs + from robohive.utils.prompt_utils import Prompt, set_prompt_verbosity + + set_prompt_verbosity(Prompt.WARN) + cls.env_list += robohive_envs + if not len(robohive_envs): + raise RuntimeError("did not load any environment.") + + @implement_for( + "gym", "0.14", None + ) # make sure gym 0.13 is installed, otherwise raise an exception + def _build_env(self, *args, **kwargs): + raise NotImplementedError( + "Your gym version is too recent, RoboHiveEnv is only compatible with gym 0.13." + ) + + @implement_for( + "gym", "0.13", "0.14" + ) # make sure gym 0.13 is installed, otherwise raise an exception + def _build_env( # noqa: F811 + self, + env_name: str, + from_pixels: bool = False, + pixels_only: bool = False, + **kwargs, + ) -> "gym.core.Env": # noqa: F821 + if from_pixels: + if "cameras" not in kwargs: + warnings.warn( + "from_pixels=True will lead to a registration of ALL available cameras, " + "which may lead to performance issue. " + "Consider passing only the needed cameras through cameras=list_of_cameras. " + "The list of available cameras for a specific environment can be obtained via " + "RobohiveEnv.get_available_cams(env_name)." + ) + kwargs["cameras"] = self.get_available_cams(env_name) + cams = list(kwargs.pop("cameras")) + env_name = self.register_visual_env(cams=cams, env_name=env_name) + + elif "cameras" in kwargs and kwargs["cameras"]: + raise RuntimeError("Got a list of cameras but from_pixels is set to False.") + + self.pixels_only = pixels_only + try: + render_device = int(str(self.device)[-1]) + except ValueError: + render_device = 0 + + if not _has_robohive: + raise ImportError( + f"gym/robohive not found, unable to create {env_name}. 
" + f"Consider downloading and installing dm_control from" + f" {self.git_url}" + ) + try: + env = self.lib.make( + env_name, + frameskip=self.frame_skip, + device_id=render_device, + return_dict=True, + **kwargs, + ) + self.wrapper_frame_skip = 1 + if env.visual_keys: + from_pixels = bool(len(env.visual_keys)) + else: + from_pixels = False + except TypeError as err: + if "unexpected keyword argument 'frameskip" not in str(err): + raise err + kwargs.pop("framek_skip") + env = self.lib.make( + env_name, return_dict=True, device_id=render_device, **kwargs + ) + self.wrapper_frame_skip = self.frame_skip + # except Exception as err: + # raise RuntimeError(f"Failed to build env {env_name}.") from err + self.from_pixels = from_pixels + self.render_device = render_device + if kwargs.get("read_info", True): + self.set_info_dict_reader(self.read_info) + return env + + @classmethod + def register_visual_env(cls, env_name, cams): + with set_directory(cls.CURR_DIR): + from robohive.envs.env_variants import register_env_variant + + if not len(cams): + raise RuntimeError("Cannot create a visual envs without cameras.") + cams = sorted(cams) + new_env_name = "-".join([cam[:-3] for cam in cams] + [env_name]) + if new_env_name in cls.env_list: + return new_env_name + visual_keys = [f"rgb:{c}:224x224:2d" for c in cams] + register_env_variant( + env_name, + variants={ + "visual_keys": visual_keys, + }, + variant_id=new_env_name, + ) + env_name = new_env_name + cls.env_list += [env_name] + return env_name + + def _refine_specs(self) -> None: # noqa: F821 + env = self._env + self.action_spec = _gym_to_torchrl_spec_transform( + env.action_space, device=self.device + ) + # get a np rollout + rollout = TensorDict({"done": torch.zeros(3, 1)}, [3]) + env.reset() + + def get_obs(): + _dict = {} + obs_dict = copy(env.obs_dict) + if self.from_pixels: + visual = self.env.get_exteroception() + obs_dict.update(visual) + pixel_list = [] + for obs_key in obs_dict: + if obs_key.startswith("rgb"): + pix = obs_dict[obs_key] + if not pix.shape[0] == 1: + pix = pix[None] + pixel_list.append(pix) + elif obs_key in env.obs_keys: + value = env.obs_dict[obs_key] + if not value.shape: + value = value[None] + _dict[obs_key] = value + if pixel_list: + _dict["pixels"] = np.concatenate(pixel_list, 0) + return _dict + + for i in range(3): + _dict = {} + _dict.update(get_obs()) + _dict["action"] = action = env.action_space.sample() + _, r, d, _ = env.step(action) + _dict[("next", "reward")] = r.reshape(1) + _dict[("next", "done")] = [1] + _dict["next"] = get_obs() + rollout[i] = TensorDict(_dict, []) + + observation_spec = make_composite_from_td( + rollout.get("next").exclude("done", "reward")[0] + ) + self.observation_spec = observation_spec + + self.reward_spec = UnboundedContinuousTensorSpec( + shape=(1,), + device=self.device, + ) # default + + rollout = self.rollout(2, return_contiguous=False).get("next") + rollout = rollout.exclude( + self.reward_key, *self.done_keys, *self.observation_spec.keys(True, True) + ) + rollout = rollout[..., 0] + spec = make_composite_from_td(rollout) + self.observation_spec.update(spec) + + def set_from_pixels(self, from_pixels: bool) -> None: + """Sets the from_pixels attribute to an existing environment. 
+ + Args: + from_pixels (bool): new value for the from_pixels attribute + + """ + if from_pixels is self.from_pixels: + return + self.from_pixels = from_pixels + self._refine_specs() + + def read_obs(self, observation): + # the info is missing from the reset + observations = self.env.obs_dict + try: + del observations["t"] + except KeyError: + pass + # recover vec + obsdict = {} + pixel_list = [] + if self.from_pixels: + visual = self.env.get_exteroception() + observations.update(visual) + for key in observations: + if key.startswith("rgb"): + pix = observations[key] + if not pix.shape[0] == 1: + pix = pix[None] + pixel_list.append(pix) + elif key in self._env.obs_keys: + value = observations[key] + if not value.shape: + value = value[None] + obsdict[key] = value # ravel helps with images + # if obsvec: + # obsvec = np.concatenate(obsvec, 0) + if self.from_pixels: + obsdict.update({"pixels": np.concatenate(pixel_list, 0)}) + out = obsdict + return super().read_obs(out) + + def read_info(self, info, tensordict_out): + out = {} + for key, value in info.items(): + if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"): + continue + if isinstance(value, dict): + value = {key: _val for key, _val in value.items() if _val is not None} + value = make_tensordict(value, batch_size=[]) + if value is not None: + out[key] = value + tensordict_out.update(out) + tensordict_out.update( + tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x) + ) + return tensordict_out + + def _init_env(self): + pass + + def to(self, *args, **kwargs): + out = super().to(*args, **kwargs) + try: + render_device = int(str(out.device)[-1]) + except ValueError: + render_device = 0 + if render_device != self.render_device: + out._build_env(**self._constructor_kwargs) + return out + + @classmethod + def get_available_cams(cls, env_name): + import gym + + env = gym.make(env_name) + cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)] + return cams diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py new file mode 100644 index 00000000000..906b9c456bd --- /dev/null +++ b/torchrl/envs/libs/smacv2.py @@ -0,0 +1,655 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import importlib +import re + +from typing import Dict, Optional + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + OneHotDiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs.common import _EnvWrapper + +from torchrl.envs.utils import _classproperty, ACTION_MASK_ERROR + +_has_smacv2 = importlib.util.find_spec("smacv2") is not None + + +def _get_envs(): + if not _has_smacv2: + raise ImportError("SMAC-v2 is not installed in your virtual environment.") + from smacv2.env.starcraft2.maps import smac_maps + + return list(smac_maps.get_smac_map_registry().keys()) + + +class SMACv2Wrapper(_EnvWrapper): + """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + + To install the environment follow the following `guide `__. 
+ + Examples: + >>> from torchrl.envs.libs.smacv2 import SMACv2Wrapper + >>> import smacv2 + >>> print(SMACv2Wrapper.available_envs) + ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', + '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', + '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg'] + >>> # You can use old SMAC maps + >>> env = SMACv2Wrapper(smacv2.env.StarCraft2Env(map_name="MMM2"), categorical_actions=False) + >>> print(env.rollout(5)) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + >>> # Or the new features for procedural generation + >>> distribution_config = { + ... "n_units": 5, + ... "n_enemies": 6, + ... "team_gen": { + ... "dist_type": "weighted_teams", + ... "unit_types": ["marine", "marauder", "medivac"], + ... "exception_unit_types": ["medivac"], + ... "weights": [0.5, 0.2, 0.3], + ... "observe": True, + ... }, + ... "start_positions": { + ... "dist_type": "surrounded_and_reflect", + ... "p": 0.5, + ... 
"n_enemies": 5, + ... "map_x": 32, + ... "map_y": 32, + ... }, + ... } + >>> env = SMACv2Wrapper( + ... smacv2.env.StarCraft2Env( + ... map_name="10gen_terran", + ... capability_config=distribution_config, + ... ) + ... ) + >>> print(env.rollout(4)) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False) + """ + + git_url = "https://github.com/oxwhirl/smacv2" + libname = "smacv2" + + @_classproperty + def available_envs(cls): + if not _has_smacv2: + return + yield from _get_envs() + + def __init__( + self, + env: "smacv2.env.StarCraft2Env" = None, # noqa: F821 + categorical_actions: bool = True, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + self.categorical_actions = categorical_actions + + super().__init__(**kwargs) + + @property + def lib(self): + import smacv2 + + return smacv2 + + def _check_kwargs(self, kwargs: Dict): + import smacv2 + + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, smacv2.env.StarCraft2Env): + raise TypeError("env is not of type 'smacv2.env.StarCraft2Env'.") + + def 
_build_env( + self, + env: "smacv2.env.StarCraft2Env", # noqa: F821 + ): + if len(self.batch_size): + raise RuntimeError( + f"SMACv2 does not support custom batch_size {self.batch_size}." + ) + + return env + + def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821 + self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} + self.reward_spec = UnboundedContinuousTensorSpec( + shape=torch.Size((1,)), + device=self.device, + ) + self.done_spec = DiscreteTensorSpec( + n=2, + shape=torch.Size((1,)), + dtype=torch.bool, + device=self.device, + ) + self.action_spec = self._make_action_spec() + self.observation_spec = self._make_observation_spec() + + def _init_env(self) -> None: + self._env.reset() + self._update_action_mask() + + def _make_action_spec(self) -> CompositeSpec: + if self.categorical_actions: + action_spec = DiscreteTensorSpec( + self.n_actions, + shape=torch.Size((self.n_agents,)), + device=self.device, + dtype=torch.long, + ) + else: + action_spec = OneHotDiscreteTensorSpec( + self.n_actions, + shape=torch.Size((self.n_agents, self.n_actions)), + device=self.device, + dtype=torch.long, + ) + spec = CompositeSpec( + { + "agents": CompositeSpec( + {"action": action_spec}, shape=torch.Size((self.n_agents,)) + ) + } + ) + return spec + + def _make_observation_spec(self) -> CompositeSpec: + obs_spec = BoundedTensorSpec( + low=-1.0, + high=1.0, + shape=torch.Size([self.n_agents, self.get_obs_size()]), + device=self.device, + dtype=torch.float32, + ) + info_spec = CompositeSpec( + { + "battle_won": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device + ), + "episode_limit": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device + ), + "dead_allies": BoundedTensorSpec( + low=0, + high=self.n_agents, + dtype=torch.long, + device=self.device, + shape=(), + ), + "dead_enemies": BoundedTensorSpec( + low=0, + high=self.n_enemies, + dtype=torch.long, + device=self.device, + shape=(), + ), + } + ) + mask_spec = DiscreteTensorSpec( + 2, + torch.Size([self.n_agents, self.n_actions]), + device=self.device, + dtype=torch.bool, + ) + spec = CompositeSpec( + { + "agents": CompositeSpec( + {"observation": obs_spec, "action_mask": mask_spec}, + shape=torch.Size((self.n_agents,)), + ), + "state": BoundedTensorSpec( + low=-1.0, + high=1.0, + shape=torch.Size((self.get_state_size(),)), + device=self.device, + dtype=torch.float32, + ), + "info": info_spec, + } + ) + return spec + + def _set_seed(self, seed: Optional[int]): + if seed is not None: + raise NotImplementedError( + "Seed cannot be changed once environment was created." 
+ ) + + def get_obs(self): + obs = self._env.get_obs() + return self._to_tensor(obs) + + def get_state(self): + state = self._env.get_state() + return self._to_tensor(state) + + def _to_tensor(self, value): + return torch.tensor(value, device=self.device, dtype=torch.float32) + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + obs, state = self._env.reset() + + # collect outputs + obs = self._to_tensor(obs) + state = self._to_tensor(state) + info = self.observation_spec["info"].zero() + + mask = self._update_action_mask() + + # build results + agents_td = TensorDict( + {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,) + ) + tensordict_out = TensorDict( + source={"agents": agents_td, "state": state, "info": info}, + batch_size=(), + device=self.device, + ) + + return tensordict_out + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # perform actions + action = tensordict.get(("agents", "action")) + action_np = self.action_spec.to_numpy(action) + + # Actions are validated by the environment. + try: + reward, done, info = self._env.step(action_np) + except AssertionError as err: + if re.match(r"Agent . cannot perform action .", str(err)): + raise ACTION_MASK_ERROR + else: + raise err + + # collect outputs + obs = self.get_obs() + state = self.get_state() + info = self.observation_spec["info"].encode(info) + actual_keys = info.keys() + for expected_key, spec in self.observation_spec["info"].items(): + if expected_key not in actual_keys: + info[expected_key] = spec.zero() + + reward = torch.tensor( + reward, device=self.device, dtype=torch.float32 + ).unsqueeze(-1) + done = torch.tensor(done, device=self.device, dtype=torch.bool).unsqueeze(-1) + + mask = self._update_action_mask() + + # build results + agents_td = TensorDict( + {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,) + ) + + tensordict_out = TensorDict( + source={ + "agents": agents_td, + "state": state, + "info": info, + "reward": reward, + "done": done, + "terminated": done.clone(), + }, + batch_size=(), + device=self.device, + ) + + return tensordict_out + + def _update_action_mask(self): + mask = torch.tensor( + self.get_avail_actions(), dtype=torch.bool, device=self.device + ) + self.action_spec.update_mask(mask) + return mask + + def close(self): + # Closes StarCraft II + self._env.close() + + def get_agent_type(self, agent_index: int) -> str: + """Get the agent type string. + + Given the agent index, get its unit type name. 
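For example (an illustrative call; the returned name depends on the map and on unit generation):

>>> env.get_agent_type(0)
'marine'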
+ + Args: + agent_index (int): the index of the agent to get the type of + + """ + if agent_index < 0 or agent_index >= self.n_agents: + raise ValueError(f"Agent index out of range, {self.n_agents} available") + + agent_info = self.agents[agent_index] + if agent_info.unit_type == self.marine_id: + return "marine" + elif agent_info.unit_type == self.marauder_id: + return "marauder" + elif agent_info.unit_type == self.medivac_id: + return "medivac" + elif agent_info.unit_type == self.hydralisk_id: + return "hydralisk" + elif agent_info.unit_type == self.zergling_id: + return "zergling" + elif agent_info.unit_type == self.baneling_id: + return "baneling" + elif agent_info.unit_type == self.stalker_id: + return "stalker" + elif agent_info.unit_type == self.colossus_id: + return "colossus" + elif agent_info.unit_type == self.zealot_id: + return "zealot" + else: + raise AssertionError(f"Agent type {agent_info.unit_type} unidentified") + + # This patches the bug in https://github.com/oxwhirl/smacv2/issues/33 + def render(self, mode: str = "human"): + import smacv2 + + if isinstance(self._env, smacv2.env.StarCraftCapabilityEnvWrapper): + return self._env.env.render(mode=mode) + else: + return self._env.render(mode=mode) + + +class SMACv2Env(SMACv2Wrapper): + """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + + To install the environment follow the following `guide `__. + + Examples: + >>> from torchrl.envs.libs.smacv2 import SMACv2Env + >>> print(SMACv2Env.available_envs) + ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', + '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', + '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg'] + >>> # You can use old SMAC maps + >>> env = SMACv2Env(map_name="MMM2") + >>> print(env.rollout(5) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, 
is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + >>> # Or the new features for procedural generation + >>> distribution_config = { + ... "n_units": 5, + ... "n_enemies": 6, + ... "team_gen": { + ... "dist_type": "weighted_teams", + ... "unit_types": ["marine", "marauder", "medivac"], + ... "exception_unit_types": ["medivac"], + ... "weights": [0.5, 0.2, 0.3], + ... "observe": True, + ... }, + ... "start_positions": { + ... "dist_type": "surrounded_and_reflect", + ... "p": 0.5, + ... "n_enemies": 5, + ... "map_x": 32, + ... "map_y": 32, + ... }, + ... } + >>> env = SMACv2Env( + ... map_name="10gen_terran", + ... capability_config=distribution_config, + ... categorical_actions=False, + ... ) + >>> print(env.rollout(4)) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: 
Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False) + """ + + def __init__( + self, + map_name: str, + capability_config: Optional[Dict] = None, + seed: Optional[int] = None, + categorical_actions: bool = True, + **kwargs, + ): + if not _has_smacv2: + raise ImportError( + f"smacv2 python package was not found. Please install this dependency. " + f"More info: {self.git_url}." + ) + kwargs["map_name"] = map_name + kwargs["capability_config"] = capability_config + kwargs["seed"] = seed + kwargs["categorical_actions"] = categorical_actions + + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + if "map_name" not in kwargs: + raise TypeError("Expected 'map_name' to be part of kwargs") + + def _build_env( + self, + map_name: str, + capability_config: Optional[Dict] = None, + seed: Optional[int] = None, + **kwargs, + ) -> "smacv2.env.StarCraft2Env": # noqa: F821 + import smacv2.env + + if capability_config is not None: + env = smacv2.env.StarCraftCapabilityEnvWrapper( + capability_config=capability_config, + map_name=map_name, + seed=seed, + **kwargs, + ) + else: + env = smacv2.env.StarCraft2Env(map_name=map_name, seed=seed, **kwargs) + + return super()._build_env(env) diff --git a/torchrl/envs/libs/unity.py b/torchrl/envs/libs/unity.py new file mode 100644 index 00000000000..5e8a8fa747f --- /dev/null +++ b/torchrl/envs/libs/unity.py @@ -0,0 +1,385 @@ +from warnings import warn + +import numpy as np +import torch +from tensordict.tensordict import TensorDict, TensorDictBase +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + MultiDiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.utils import numpy_to_torch_dtype_dict +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import _classproperty + +IMPORT_ERR = None +try: + from mlagents_envs.base_env import ActionSpec, ActionTuple, BaseEnv, ObservationSpec + from mlagents_envs.environment import SideChannel, UnityEnvironment + + _has_mlagents = True +except ImportError as err: + _has_mlagents = False + IMPORT_ERR = err + + +__all__ = ["UnityWrapper", "UnityEnv"] + + +def _unity_to_torchrl_spec_transform(spec, dtype=None, device="cpu"): + """Maps the Unity specs to the TorchRL specs.""" + if isinstance(spec, ObservationSpec): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) + dtype = numpy_to_torch_dtype_dict[dtype] + return UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) + elif isinstance(spec, ActionSpec): + if spec.continuous_size == spec.discrete_size == 0: + raise ValueError("No available actions") + d_spec = c_spec = None + if spec.discrete_size == 1: + d_spec = DiscreteTensorSpec( + spec.discrete_branches[0], shape=[spec.discrete_size], device=device + ) + elif spec.discrete_size > 1: + d_spec = MultiDiscreteTensorSpec( + spec.discrete_branches, shape=[spec.discrete_size], device=device + ) + + if spec.continuous_size > 0: + dtype = numpy_to_torch_dtype_dict[dtype] + c_spec = BoundedTensorSpec( + -1, 1, (spec.continuous_size,), dtype=dtype, 
device=device + ) + + if d_spec and c_spec: + return CompositeSpec(discrete=d_spec, continuous=c_spec) + else: + return d_spec if d_spec else c_spec + else: + raise TypeError(f"Unknown spec of type {type(spec)} passed") + + +class UnityWrapper(_EnvWrapper): + """Unity environment wrapper. + + Examples: + >>> env = UnityWrapper( + ... UnityEnvironment( + ... "<>", + ... side_channels=[], + ... additional_args=[], + ... log_folder=<>, + ... device=device, + ... ) + ... ) + """ + + git_url = "https://github.com/Unity-Technologies/ml-agents" + libname = "mlagents_envs" + + def __init__(self, env=None, **kwargs): + if env is not None: + kwargs["env"] = env + super().__init__(**kwargs) + + def _init_env(self): + pass + + def _compute_num_agents(self, env): + num_agents = 0 + for behavior_name in env.behavior_specs.keys(): + decision_steps, terminal_steps = env.get_steps(behavior_name) + num_agents += len(decision_steps) + len(terminal_steps) + return num_agents + + def _set_seed(self, seed: int | None): + warn( + "Seeding through _set_seed has not been implemented. Please set the " + "seed when you create the environment." + ) + + @_classproperty + def available_envs(cls) -> list[str]: + return [] + + def _build_env(self, env: BaseEnv): + if not env.behavior_specs: + # Take a single step so that the brain information will be sent over + env.step() + self._behavior_names = list(env.behavior_specs.keys()) + self.num_agents = self._compute_num_agents(env) + self._agent_ids = torch.tensor(range(self.num_agents), dtype=torch.int) + self._agent_id_to_behavior_name = {} + return env + + def _make_specs(self, env: BaseEnv) -> None: + observation_specs = [None] * self.num_agents + action_specs = [None] * self.num_agents + reward_specs = [None] * self.num_agents + done_specs = [None] * self.num_agents + valid_mask_specs = [None] * self.num_agents + + for behavior_name, behavior_unity_spec in env.behavior_specs.items(): + decision_steps, terminal_steps = env.get_steps(behavior_name) + for steps in [decision_steps, terminal_steps]: + for agent_id in steps.agent_id: + self._agent_id_to_behavior_name[agent_id] = behavior_name + + observation_specs[agent_id] = CompositeSpec( + { + f"obs_{agent_id}_{i}": _unity_to_torchrl_spec_transform( + spec, dtype=np.dtype("float32"), device=self.device + ) + for i, spec in enumerate( + behavior_unity_spec.observation_specs + ) + } + ) + action_specs[agent_id] = _unity_to_torchrl_spec_transform( + behavior_unity_spec.action_spec, + dtype=np.dtype("int32"), + device=self.device, + ) + reward_specs[agent_id] = UnboundedContinuousTensorSpec( + shape=[1], device=self.device + ) + done_specs[agent_id] = DiscreteTensorSpec( + n=2, shape=[1], dtype=torch.bool, device=self.device + ) + valid_mask_specs[agent_id] = DiscreteTensorSpec( + n=2, shape=[1], dtype=torch.bool, device=self.device + ) + + self.observation_spec = CompositeSpec( + { + "agents": CompositeSpec( + {"observation": torch.stack(observation_specs, dim=0)}, + shape=(self.num_agents,), + ) + } + ) + self.action_spec = CompositeSpec( + { + "agents": CompositeSpec( + {"action": torch.stack(action_specs, dim=0)}, + shape=(self.num_agents,), + ) + } + ) + self.reward_spec = CompositeSpec( + { + "agents": CompositeSpec( + {"reward": torch.stack(reward_specs, dim=0)}, + shape=(self.num_agents,), + ) + } + ) + self.done_spec = CompositeSpec( + { + "agents": CompositeSpec( + {"done": torch.stack(done_specs, dim=0)}, shape=(self.num_agents,) + ) + } + ) + self.valid_mask_spec = CompositeSpec( + { + "agents": CompositeSpec( + 
{"valid_mask": torch.stack(valid_mask_specs, dim=0)}, + shape=(self.num_agents,), + ) + } + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})" + ) + + def _check_kwargs(self, kwargs: dict): + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, BaseEnv): + raise TypeError("env is not of type 'mlagents_envs.base_env.BaseEnv'.") + if "frame_skip" in kwargs and kwargs["frame_skip"] != 1: + # FIXME: Add support for this. + raise ValueError( + "Currently, frame_skip is not supported for Unity environments." + ) + + def agent_id_to_behavior_name(self, agent_id: int): + return self._agent_id_to_behavior_name[agent_id] + + def read_obs(self, agent_id, obs): + return self.observation_spec["agents", "observation"][agent_id].encode( + {f"obs_{agent_id}_{i}": observation for i, observation in enumerate(obs)}, + ) + + def read_reward(self, agent_id, reward): + return self.reward_spec[agent_id].encode(reward) + + def read_valid_mask(self, agent_id, valid): + return self.valid_mask_spec["agents", "valid_mask"][agent_id].encode(valid) + + def read_action(self, action): + action = self.action_spec.to_numpy(action, safe=False) + # Actions are defined to be 2D arrays with the first dimension + # used for the number of agents in the game and the second + # dimension used for the action. + if isinstance(action, dict): + action = { + k: np.reshape(v, (1, np.prod(v.shape))) for k, v in action.items() + } + else: + action = np.reshape(action, (1, np.prod(action.shape))) + + if isinstance(self.action_spec, CompositeSpec): + action = ActionTuple(action["continuous"], action["discrete"]) + elif isinstance(self.action_spec, DiscreteTensorSpec | MultiDiscreteTensorSpec): + action = ActionTuple(None, action) + else: + action = ActionTuple(action, None) + return action + + def read_done(self, agent_id, done): + return self.done_spec[agent_id].encode(done) + + def _get_next_tensordict(self): + agent_tds = [None] * self.num_agents + seen_agent_ids = set() + + for behavior_name_ in self.behavior_specs.keys(): + decision_steps, terminal_steps = self.get_steps(behavior_name_) + for i, steps in enumerate([decision_steps, terminal_steps]): + for agent_id in steps.agent_id: + agent_id = int(agent_id) + step = steps[agent_id] + done = False if i == 0 else True + seen_agent_ids.add(agent_id) + + agent_td = TensorDict( + source={ + "observation": self.read_obs(agent_id, step.obs), + "reward": self.read_reward(agent_id, step.reward), + "done": self.read_done(agent_id, done), + "valid_mask": self.read_valid_mask(agent_id, True), + }, + batch_size=[], + ) + agent_tds[agent_id] = agent_td + + missing_agents = set(range(self.num_agents)) - seen_agent_ids + for missing_agent in missing_agents: + agent_td = TensorDict( + source={ + "observation": self.observation_spec["agents", "observation"][ + missing_agent + ].zero(), + "reward": self.reward_spec[missing_agent].zero(), + "done": self.done_spec[missing_agent].zero(), + "valid_mask": self.read_valid_mask(missing_agent, False), + }, + batch_size=[], + ) + agent_tds[missing_agent] = agent_td + + agents_td = torch.stack(agent_tds, dim=0) + tensordict_out = TensorDict(source={"agents": agents_td}, batch_size=[]) + return tensordict_out + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # FIXME: Figure out why tensordict["agents", "valid_mask"] and tensordict["agents", "done"] + # have different shapes which require us to
squeeze. + eligible_agent_mask = torch.logical_and( + torch.squeeze(tensordict["agents", "valid_mask"]), + torch.logical_not(torch.squeeze(tensordict["agents", "done"])), + ) + agent_ids = self._agent_ids[eligible_agent_mask] + actions = tensordict["agents", "action"].unsqueeze(-1)[eligible_agent_mask] + for action, agent_id in zip(actions, agent_ids): + unity_action = self.read_action(action) + self.set_action_for_agent( + self.agent_id_to_behavior_name(agent_id.item()), + agent_id.item(), + unity_action, + ) + self._env.step() + tensordict_out = self._get_next_tensordict() + return tensordict_out.select().set("next", tensordict_out) + + def _reset(self, tensordict: TensorDictBase | None = None, **kwargs): + self._env.reset(**kwargs) + tensordict_out = self._get_next_tensordict() + return tensordict_out + + +class UnityEnv(UnityWrapper): + """Unity environment wrapper. + + Examples: + >>> env = UnityEnv( + ... "<>", + ... side_channels=[], + ... additional_args=[], + ... log_folder=<>, + ... device=device, + ... ) + """ + + def __init__( + self, + file_name: str | None = None, + seed: int = 0, + no_graphics: bool = False, + timeout_wait: int = 60, + side_channels: list[SideChannel] | None = None, + log_folder: str | None = None, + **kwargs, + ): + kwargs["file_name"] = file_name + kwargs["seed"] = seed + kwargs["no_graphics"] = no_graphics + kwargs["timeout_wait"] = timeout_wait + kwargs["side_channels"] = side_channels + kwargs["log_folder"] = log_folder + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: dict): + if "file_name" not in kwargs: + raise TypeError("Could not find environment key 'file_name' in kwargs.") + + def _build_env( + self, + file_name: str | None = None, + seed: int = 0, + no_graphics: bool = False, + timeout_wait: int = 60, + side_channels: list[SideChannel] | None = None, + log_folder: str | None = None, + **env_kwargs, + ): + if not _has_mlagents: + raise RuntimeError( + f"Unity MLAgents not found, unable to create environment. " + f"Consider downloading and installing Unity MLAgents from" + f" {self.git_url}" + ) + self.file_name = file_name + return super()._build_env( + UnityEnvironment( + file_name, + seed=seed, + no_graphics=no_graphics, + timeout_wait=timeout_wait, + side_channels=side_channels, + log_folder=log_folder, + **env_kwargs, + ) + ) + + def __repr__(self): + return f"{super().__repr__()}(file_name={self.file_name})" diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 04620804d8a..6cc4b54705b 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -1,4 +1,10 @@ -from typing import Dict, List, Optional, Union +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import importlib.util + +from typing import Dict, Optional, Union import torch from tensordict.tensordict import TensorDict, TensorDictBase @@ -7,28 +13,24 @@ CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, + LazyStackedCompositeSpec, UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper, EnvBase from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend -from torchrl.envs.utils import _selective_unsqueeze - -IMPORT_ERR = None -try: - import vmas +from torchrl.envs.utils import _classproperty, _selective_unsqueeze - _has_vmas = True +_has_vmas = importlib.util.find_spec("vmas") is not None -except ImportError as err: - _has_vmas = False - IMPORT_ERR = err __all__ = ["VmasWrapper", "VmasEnv"] -def _get_envs() -> List: +def _get_envs(): if not _has_vmas: - return [] + raise ImportError("VMAS is not installed in your virtual environment.") + import vmas + all_scenarios = vmas.scenarios + vmas.mpe_scenarios + vmas.debug_scenarios # TODO heterogenous spaces # For now torchrl does not support heterogenous spaces (Tple(Box)) so many OpenAI MPE scenarios do not work @@ -67,59 +69,80 @@ class VmasWrapper(_EnvWrapper): >>> print(env.rollout(10)) TensorDict( fields={ - action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - info: TensorDict( + agents: TensorDict( fields={ - agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False), + info: TensorDict( + fields={ + agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([32, 10, 5]), device=cpu, is_shared=False), + done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - done: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.bool, is_shared=False), - info: TensorDict( + agents: TensorDict( fields={ - agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + info: TensorDict( + fields={ + agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([32, 10, 5]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False), - reward: 
Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([32, 10]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([32, 10]), device=cpu, is_shared=False) - """ git_url = "https://github.com/proroklab/VectorizedMultiAgentSimulator" libname = "vmas" - available_envs = _get_envs() + + @property + def lib(self): + import vmas + + return vmas + + @_classproperty + def available_envs(cls): + if not _has_vmas: + return + yield from _get_envs() def __init__( - self, env: "vmas.simulator.environment.environment.Environment" = None, **kwargs + self, + env: "vmas.simulator.environment.environment.Environment" = None, # noqa + categorical_actions: bool = True, + **kwargs, ): if env is not None: kwargs["env"] = env if "device" in kwargs.keys() and kwargs["device"] != str(env.device): raise TypeError("Env device is different from vmas device") kwargs["device"] = str(env.device) - super().__init__(**kwargs) - - @property - def lib(self): - return vmas + self.categorical_actions = categorical_actions + super().__init__(**kwargs, allow_done_after_reset=True) def _build_env( self, - env: "vmas.simulator.environment.environment.Environment", + env: "vmas.simulator.environment.environment.Environment", # noqa from_pixels: bool = False, pixels_only: bool = False, ): @@ -149,7 +172,7 @@ def _build_env( @set_gym_backend("gym") def _make_specs( - self, env: "vmas.simulator.environment.environment.Environment" + self, env: "vmas.simulator.environment.environment.Environment" # noqa ) -> None: # TODO heterogenous spaces @@ -160,24 +183,38 @@ def _make_specs( info_specs = [] for agent_index, agent in enumerate(self.agents): action_specs.append( - _gym_to_torchrl_spec_transform( - self.action_space[agent_index], - categorical_action_encoding=True, - device=self.device, + CompositeSpec( + { + "action": _gym_to_torchrl_spec_transform( + self.action_space[agent_index], + categorical_action_encoding=self.categorical_actions, + device=self.device, + remap_state_to_observation=False, + ) # shape = (n_actions_per_agent,) + }, ) - ) # shape = (n_actions_per_agent,) + ) observation_specs.append( - _gym_to_torchrl_spec_transform( - self.observation_space[agent_index], - device=self.device, + CompositeSpec( + { + "observation": _gym_to_torchrl_spec_transform( + self.observation_space[agent_index], + device=self.device, + remap_state_to_observation=False, + ) # shape = (n_obs_per_agent,) + }, ) - ) # shape = (n_obs_per_agent,) + ) reward_specs.append( - UnboundedContinuousTensorSpec( - shape=torch.Size((1,)), - device=self.device, + CompositeSpec( + { + "reward": UnboundedContinuousTensorSpec( + shape=torch.Size((1,)), + device=self.device, + ) # shape = (1,) + } ) - ) # shape = (1,) + ) agent_info = self.scenario.info(agent) if len(agent_info): info_specs.append( @@ -185,7 +222,7 @@ def _make_specs( { key: UnboundedContinuousTensorSpec( shape=_selective_unsqueeze( - value, batch_size=torch.Size((self.num_envs,)) + value, batch_size=self.batch_size ).shape[1:], 
device=self.device, dtype=torch.float32, @@ -198,13 +235,17 @@ def _make_specs( # Create multi-agent specs multi_agent_action_spec = torch.stack( action_specs, dim=0 - ) # UnboundedContinuousTensorSpec with shape = (n_agents, n_actions_per_agent) + ) # shape = (n_agents, n_actions_per_agent) multi_agent_observation_spec = torch.stack( observation_specs, dim=0 - ) # UnboundedContinuousTensorSpec with shape = (n_agents, n_obs_per_agent) + ) # shape = (n_agents, n_obs_per_agent) multi_agent_reward_spec = torch.stack( reward_specs, dim=0 - ) # UnboundedContinuousTensorSpec with shape = (n_agents, 1) + ) # shape = (n_agents, 1) + + self.het_specs = isinstance( + multi_agent_observation_spec, LazyStackedCompositeSpec + ) or isinstance(multi_agent_action_spec, LazyStackedCompositeSpec) done_spec = DiscreteTensorSpec( n=2, @@ -213,32 +254,15 @@ def _make_specs( device=self.device, ) # shape = (1,) - self.unbatched_action_spec = CompositeSpec( - { - "agents": CompositeSpec( - {"action": multi_agent_action_spec}, shape=(self.n_agents,) - ) - } - ) + self.unbatched_action_spec = CompositeSpec({"agents": multi_agent_action_spec}) self.unbatched_observation_spec = CompositeSpec( - { - "agents": CompositeSpec( - {"observation": multi_agent_observation_spec}, - shape=(self.n_agents,), - ) - } + {"agents": multi_agent_observation_spec} ) if len(info_specs): multi_agent_info_spec = torch.stack(info_specs, dim=0) self.unbatched_observation_spec[("agents", "info")] = multi_agent_info_spec - self.unbatched_reward_spec = CompositeSpec( - { - "agents": CompositeSpec( - {"reward": multi_agent_reward_spec}, shape=(self.n_agents,) - ) - } - ) + self.unbatched_reward_spec = CompositeSpec({"agents": multi_agent_reward_spec}) self.unbatched_done_spec = done_spec self.action_spec = self.unbatched_action_spec.expand( @@ -255,6 +279,8 @@ def _make_specs( ) def _check_kwargs(self, kwargs: Dict): + vmas = self.lib + if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -299,20 +325,24 @@ def _reset( agent_td = TensorDict( source={ - "agents": { - "observation": agent_obs, - }, + "observation": agent_obs, }, - batch_size=(self.num_envs,), + batch_size=self.batch_size, device=self.device, ) if agent_info is not None: - agent_td.set(("agents", "info"), agent_info) + agent_td.set("info", agent_info) agent_tds.append(agent_td) - tensordict_out = torch.stack(agent_tds, dim=1).to_tensordict() - tensordict_out.batch_size = self.batch_size - tensordict_out.set("done", dones) + agent_tds = torch.stack(agent_tds, dim=1) + if not self.het_specs: + agent_tds = agent_tds.to_tensordict() + tensordict_out = TensorDict( + source={"agents": agent_tds, "done": dones, "terminated": dones.clone()}, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict_out def _step( @@ -334,29 +364,36 @@ def _step( agent_td = TensorDict( source={ - "agents": { - "observation": agent_obs, - "reward": agent_rew, - }, + "observation": agent_obs, + "reward": agent_rew, }, - batch_size=(self.num_envs,), + batch_size=self.batch_size, device=self.device, ) if agent_info is not None: - agent_td.set(("agents", "info"), agent_info) + agent_td.set("info", agent_info) agent_tds.append(agent_td) - tensordict_out = torch.stack(agent_tds, dim=1).to_tensordict() - tensordict_out.batch_size = self.batch_size - tensordict_out.set("done", dones) + agent_tds = torch.stack(agent_tds, dim=1) + if not self.het_specs: + agent_tds = agent_tds.to_tensordict() + tensordict_out = TensorDict( + 
source={"agents": agent_tds, "done": dones, "terminated": dones.clone()}, + batch_size=self.batch_size, + device=self.device, + ) - return tensordict_out.select().set("next", tensordict_out) + return tensordict_out - def read_obs(self, observations: torch.Tensor) -> torch.Tensor: - observations = _selective_unsqueeze( - observations, batch_size=torch.Size((self.num_envs,)) + def read_obs( + self, observations: Union[Dict, torch.Tensor] + ) -> Union[Dict, torch.Tensor]: + if isinstance(observations, torch.Tensor): + return _selective_unsqueeze(observations, batch_size=self.batch_size) + return TensorDict( + source={key: self.read_obs(value) for key, value in observations.items()}, + batch_size=self.batch_size, ) - return observations def read_info(self, infos: Dict[str, torch.Tensor]) -> torch.Tensor: if len(infos) == 0: @@ -364,25 +401,29 @@ def read_info(self, infos: Dict[str, torch.Tensor]) -> torch.Tensor: infos = TensorDict( source={ key: _selective_unsqueeze( - value.to(torch.float32), batch_size=torch.Size((self.num_envs,)) + value.to(torch.float32), batch_size=self.batch_size ) for key, value in infos.items() }, - batch_size=torch.Size((self.num_envs,)), + batch_size=self.batch_size, device=self.device, ) return infos def read_done(self, done): - done = _selective_unsqueeze(done, batch_size=torch.Size((self.num_envs,))) + done = _selective_unsqueeze(done, batch_size=self.batch_size) return done def read_reward(self, rewards): - rewards = _selective_unsqueeze(rewards, batch_size=torch.Size((self.num_envs,))) + rewards = _selective_unsqueeze(rewards, batch_size=self.batch_size) return rewards def read_action(self, action): + if not self.continuous_actions and not self.categorical_actions: + action = self.unbatched_action_spec["agents", "action"].to_categorical( + action + ) agent_actions = [] for i in range(self.n_agents): agent_actions.append(action[:, i, ...]) @@ -416,45 +457,55 @@ class VmasEnv(VmasWrapper): >>> print(env.rollout(10)) TensorDict( fields={ - action: Tensor(torch.Size([5, 32, 10, 2]), dtype=torch.float64), - done: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.bool), - info: TensorDict( + agents: TensorDict( fields={ - cohesion_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - collision_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - separation_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - velocity_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32)}, - batch_size=torch.Size([5, 32, 10]), + action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False), + info: TensorDict( + fields={ + agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), device=cpu, is_shared=False), + done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - info: TensorDict( + agents: TensorDict( fields={ - cohesion_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - collision_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - separation_rew: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32), - velocity_rew: 
Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32)}, - batch_size=torch.Size([5, 32, 10]), + info: TensorDict( + fields={ + agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 10, 5]), device=cpu, is_shared=False), - observation: Tensor(torch.Size([5, 32, 10, 18]), dtype=torch.float32)}, - batch_size=torch.Size([5, 32, 10]), + done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32, 10]), device=cpu, is_shared=False), - observation: Tensor(torch.Size([5, 32, 10, 18]), dtype=torch.float32), - reward: Tensor(torch.Size([5, 32, 10, 1]), dtype=torch.float32)}, - batch_size=torch.Size([5, 32, 10]), + terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32, 10]), device=cpu, is_shared=False) """ def __init__( self, - scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], + scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], # noqa num_envs: int, continuous_actions: bool = True, max_steps: Optional[int] = None, + categorical_actions: bool = True, seed: Optional[int] = None, **kwargs, ): @@ -462,12 +513,13 @@ def __init__( raise ImportError( f"vmas python package was not found. Please install this dependency. " f"More info: {self.git_url}." - ) from IMPORT_ERR + ) kwargs["scenario"] = scenario kwargs["num_envs"] = num_envs kwargs["continuous_actions"] = continuous_actions kwargs["max_steps"] = max_steps kwargs["seed"] = seed + kwargs["categorical_actions"] = categorical_actions super().__init__(**kwargs) def _check_kwargs(self, kwargs: Dict): @@ -478,13 +530,15 @@ def _check_kwargs(self, kwargs: Dict): def _build_env( self, - scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], + scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], # noqa num_envs: int, continuous_actions: bool, max_steps: Optional[int], seed: Optional[int], **scenario_kwargs, - ) -> "vmas.simulator.environment.environment.Environment": + ) -> "vmas.simulator.environment.environment.Environment": # noqa + vmas = self.lib + self.scenario_name = scenario from_pixels = scenario_kwargs.pop("from_pixels", False) pixels_only = scenario_kwargs.pop("pixels_only", False) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 1a63b0f5c45..5952132ca19 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -16,7 +16,7 @@ from torchrl.envs.common import EnvBase -class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): +class ModelBasedEnvBase(EnvBase): """Basic environnement for Model Based RL algorithms. Wrapper around the model of the MBRL algorithm. @@ -160,16 +160,12 @@ def _step( ) else: tensordict_out = self.world_model(tensordict_out) - # Step requires a done flag. 
No sense for MBRL so we set it to False - if "done" not in self.world_model.out_keys: - tensordict_out["done"] = torch.zeros( - tensordict_out.shape, - dtype=torch.bool, - device=tensordict_out.device, - ) - return tensordict_out.select().set( - "next", - tensordict_out.select(*self.observation_spec.keys(), "reward", "done"), + # done can be missing, it will be filled by `step` + return tensordict_out.select( + *self.observation_spec.keys(), + *self.full_done_spec.keys(), + *self.full_reward_spec.keys(), + strict=False, ) @abc.abstractmethod diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index bf9ac50d4c9..e36ddf9e02a 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -12,7 +12,7 @@ from torchrl.data.tensor_specs import CompositeSpec from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import EnvBase +from torchrl.envs.common import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 5ee87e2c0eb..f486a0f793f 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -3,16 +3,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .gym_transforms import EndOfLifeTransform from .r3m import R3MTransform from .rlhf import KLRewardTransform from .transforms import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, CenterCrop, + ClipTransform, Compose, + DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, + DTypeCastTransform, ExcludeTransform, FiniteTensorDictCheck, FlattenObservation, @@ -23,6 +28,7 @@ NoopResetEnv, ObservationNorm, ObservationTransform, + PermuteTransform, PinMemoryTransform, RandomCropTensorDict, RenameTransform, diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py new file mode 100644 index 00000000000..a67e526fc25 --- /dev/null +++ b/torchrl/envs/transforms/gym_transforms.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Gym-specific transforms.""" +import warnings + +import torch +import torchrl.objectives.common +from tensordict import TensorDictBase +from tensordict.utils import expand_as_right, NestedKey +from torchrl.data.tensor_specs import UnboundedDiscreteTensorSpec + +from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform + + +class EndOfLifeTransform(Transform): + """Registers the end-of-life signal from a Gym env with a `lives` method. + + Proposed by DeepMind for the DQN and co. It helps value estimation. + + Args: + eol_key (NestedKey, optional): the key where the end-of-life signal should + be written. Defaults to ``"end-of-life"``. + done_key (NestedKey, optional): a "done" key in the parent env done_spec, + where the done value can be retrieved. This key must be unique and its + shape must match the shape of the end-of-life entry. Defaults to ``"done"``. + eol_attribute (str, optional): the location of the "lives" in the gym env. + Defaults to ``"unwrapped.ale.lives"``. Supported attribute types are + integer/array-like objects or callables that return these values. + + .. note:: + This transform should be used with gym envs that have a ``env.unwrapped.ale.lives``. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs.transforms.transforms import TransformedEnv + >>> env = GymEnv("ALE/Breakout-v5") + >>> env.rollout(100) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([100]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([100]), + device=cpu, + is_shared=False) + >>> eol_transform = EndOfLifeTransform() + >>> env = TransformedEnv(env, eol_transform) + >>> env.rollout(100) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + eol: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + end-of-life: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False), + pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([100]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([100]), + device=cpu, + is_shared=False) + + The typical usage of this transform is to replace the "done" state by "end-of-life" + within the loss module. The end-of-life signal isn't registered within the ``done_spec`` + because it should not instruct the env to reset. 
+ + Examples: + >>> from torchrl.objectives import DQNLoss + >>> module = torch.nn.Identity() # used as a placeholder + >>> loss = DQNLoss(module, action_space="categorical") + >>> loss.set_keys(done="end-of-life", terminated="end-of-life") + >>> # equivalently + >>> eol_transform.register_keys(loss) + """ + + NO_PARENT_ERR = "The {} transform is being executed without a parent env. This is currently not supported." + + def __init__( + self, + eol_key: NestedKey = "end-of-life", + lives_key: NestedKey = "lives", + done_key: NestedKey = "done", + eol_attribute="unwrapped.ale.lives", + ): + super().__init__(in_keys=[done_key], out_keys=[eol_key, lives_key]) + self.eol_key = eol_key + self.lives_key = lives_key + self.done_key = done_key + self.eol_attribute = eol_attribute.split(".") + + def _get_lives(self): + from torchrl.envs.libs.gym import GymWrapper + + base_env = self.parent.base_env + if not isinstance(base_env, GymWrapper): + warnings.warn( + f"The base_env is not a gym env. Compatibility of {type(self)} is not guaranteed with " + f"environment types that do not inherit from GymWrapper.", + category=UserWarning, + ) + # getattr falls back on _env by default + lives = getattr(base_env, self.eol_attribute[0]) + for att in self.eol_attribute[1:]: + if isinstance(lives, list): + # For SerialEnv (and who knows Parallel one day) + lives = [getattr(_lives, att) for _lives in lives] + else: + lives = getattr(lives, att) + if callable(lives): + lives = lives() + elif isinstance(lives, list) and all(callable(_lives) for _lives in lives): + lives = torch.tensor([_lives() for _lives in lives]) + return lives + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + return tensordict + + def _step(self, tensordict, next_tensordict): + parent = self.parent + if parent is None: + raise RuntimeError(self.NO_PARENT_ERR.format(type(self))) + + lives = self._get_lives() + end_of_life = torch.tensor( + tensordict.get(self.lives_key) < lives, device=self.parent.device + ) + try: + done = next_tensordict.get(self.done_key) + except KeyError: + raise KeyError( + f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {tensordict.keys(True, True)}. " + f"Make sure to pass the appropriate done_key to the {type(self)} transform." + ) + end_of_life = expand_as_right(end_of_life, done) | done + next_tensordict.set(self.eol_key, end_of_life) + next_tensordict.set(self.lives_key, lives) + return next_tensordict + + def reset(self, tensordict): + parent = self.parent + if parent is None: + raise RuntimeError(self.NO_PARENT_ERR.format(type(self))) + lives = self._get_lives() + end_of_life = False + tensordict.set( + self.eol_key, + torch.tensor(end_of_life).expand( + parent.full_done_spec[self.done_key].shape + ), + ) + tensordict.set(self.lives_key, lives) + return tensordict + + def transform_observation_spec(self, observation_spec): + full_done_spec = self.parent.output_spec["full_done_spec"] + observation_spec[self.eol_key] = full_done_spec[self.done_key].clone() + observation_spec[self.lives_key] = UnboundedDiscreteTensorSpec( + self.parent.batch_size, + device=self.parent.device, + dtype=torch.int64, + ) + return observation_spec + + def register_keys(self, loss_or_advantage: "torchrl.objectives.common.LossModule"): + """Registers the end-of-life key at appropriate places within the loss. + + Args: + loss_or_advantage (torchrl.objectives.LossModule or torchrl.objectives.value.ValueEstimatorBase): a module to instruct what the end-of-life key is. 
+ + """ + loss_or_advantage.set_keys(done=self.eol_key, terminated=self.eol_key) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self))) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index c9b2053f832..bb180ecaa9d 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -2,17 +2,18 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from copy import deepcopy +from copy import copy, deepcopy import torch - from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( make_functional, ProbabilisticTensorDictModule, repopulate_module, + TensorDictParams, ) from tensordict.utils import is_seq_of_nested_key +from torch import nn from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.transforms.transforms import Transform @@ -33,6 +34,14 @@ class KLRewardTransform(Transform): reward should be fetched. Defaults to ``"reward"``. out_keys (str or list of str/tuples of str): the output key where the reward should be written. Defaults to ``"reward"``. + requires_grad (bool, optional): if ``True``, the frozen parameters will + consist of differentiable clones of the original params. + Defaults to ``False``. + + .. note:: If the parameters are not differentiable (default), they will *not* + follow the module when dtype or device casting operations will be called + (such as :meth:`~.cuda`, :meth:`~.to` etc.). When ``requires_grad=True``, + casting operations will work as expected. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -65,8 +74,7 @@ class KLRewardTransform(Transform): >>> # check that rewards have been modified >>> assert (td.get(("next", "reward")) != td.get(("next", "reward_kl"))).all() - .. note:: - Because the KL formulat is not always available and the parameters of the + .. note:: Because the KL formulat is not always available and the parameters of the original distribution may not have been recorded, we use a stochastic estimate of the KL divergence. @@ -80,28 +88,27 @@ def __init__( coef=1.0, in_keys=None, out_keys=None, + requires_grad=False, ): if in_keys is None: in_keys = self.DEFAULT_IN_KEYS if out_keys is None: - out_keys = in_keys - if not isinstance(in_keys, list): - in_keys = [in_keys] - if not isinstance(out_keys, list): - out_keys = [out_keys] - if not is_seq_of_nested_key(in_keys) or not is_seq_of_nested_key(out_keys): + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) + if not is_seq_of_nested_key(self.in_keys) or not is_seq_of_nested_key( + self.out_keys + ): raise ValueError( - f"invalid in_keys / out_keys:\nin_keys={in_keys} \nout_keys={out_keys}" + f"invalid in_keys / out_keys:\nin_keys={self.in_keys} \nout_keys={self.out_keys}" ) - if len(in_keys) != 1 or len(out_keys) != 1: + if len(self.in_keys) != 1 or len(self.out_keys) != 1: raise ValueError( - f"Only one in_key/out_key is allowed, got in_keys={in_keys}, out_keys={out_keys}." + f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}." 
) - super().__init__(in_keys=in_keys, out_keys=out_keys) # for convenience, convert out_keys to tuples - self.out_keys = [ + self._out_keys = [ out_key if isinstance(out_key, tuple) else (out_key,) - for out_key in self.out_keys + for out_key in self._out_keys ] # update the in_keys for dispatch etc @@ -115,7 +122,23 @@ def __init__( repopulate_module(actor, params) # we need to register these params as buffer to have `to` and similar # methods work properly - self.frozen_params = params.clone().detach() + + def _make_detached_param(x): + + if isinstance(x, nn.Parameter): + # we need an nn.Parameter since some modules (RNN) require nn.Parameters + return nn.Parameter(x.data.clone(), requires_grad=requires_grad) + elif x.requires_grad: + raise ValueError( + "Encountered a value that requires gradients but is not an nn.Parameter instance." + ) + return x.clone() + + self.frozen_params = params.apply(_make_detached_param) + if requires_grad: + # includes the frozen params/buffers in the module parameters/buffers + self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True) + # self._buffers["actor_params"] = params.clone().detach() # find the sample log-prob key @@ -152,7 +175,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl) return tensordict - _step = _call + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + with tensordict.unlock_(): + return self._call(tensordict.set("next", next_tensordict)).pop("next") + forward = _call def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: @@ -166,23 +194,23 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: parent = self.parent reward_spec = UnboundedContinuousTensorSpec( device=output_spec.device, - shape=output_spec["_reward_spec"][parent.reward_key].shape, + shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) - output_spec["_reward_spec"] = CompositeSpec( + output_spec["full_reward_spec"] = CompositeSpec( {parent.reward_key: reward_spec}, - shape=output_spec["_reward_spec"].shape, + shape=output_spec["full_reward_spec"].shape, ) elif in_key == "reward": parent = self.parent reward_spec = UnboundedContinuousTensorSpec( device=output_spec.device, - shape=output_spec["_reward_spec"][parent.reward_key].shape, + shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) # then we need to populate the output keys - observation_spec = output_spec["_observation_spec"] + observation_spec = output_spec["full_observation_spec"] observation_spec[out_key] = reward_spec else: - observation_spec = output_spec["_observation_spec"] + observation_spec = output_spec["full_observation_spec"] reward_spec = UnboundedContinuousTensorSpec( device=output_spec.device, shape=observation_spec[in_key].shape ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0f17c37b5f4..1e4a277c220 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8,13 +8,17 @@ import collections import multiprocessing as mp import warnings -from copy import copy, deepcopy +from copy import copy +from functools import wraps from textwrap import indent -from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Union + +import numpy as np import torch from tensordict import unravel_key, 
unravel_key_list +from tensordict._tensordict import _unravel_key_to_tuple from tensordict.nn import dispatch from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import expand_as_right, NestedKey @@ -27,15 +31,16 @@ ContinuousBox, DEVICE_TYPING, DiscreteTensorSpec, + MultiDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, ) -from torchrl.envs.common import EnvBase, make_tensordict +from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import check_finite -from torchrl.envs.utils import _sort_keys, step_mdp +from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp from torchrl.objectives.value.functional import reward2go try: @@ -61,14 +66,17 @@ def interpolation_fn(interpolation): # noqa: D103 IMAGE_KEYS = ["pixels"] _MAX_NOOPS_TRIALS = 10 -FORWARD_NOT_IMPLEMENTED = "class {} cannot be executed without a parent" "environment." +FORWARD_NOT_IMPLEMENTED = "class {} cannot be executed without a parent environment." def _apply_to_composite(function): + @wraps(function) def new_fun(self, observation_spec): if isinstance(observation_spec, CompositeSpec): d = observation_spec._specs - for in_key, out_key in zip(self.in_keys, self.out_keys): + in_keys = self.in_keys + out_keys = self.out_keys + for in_key, out_key in zip(in_keys, out_keys): if in_key in observation_spec.keys(True, True): d[out_key] = function(self, observation_spec[in_key].clone()) return CompositeSpec( @@ -90,13 +98,15 @@ def _apply_to_composite_inv(function): # tensor is not updated) an out_key that does not match the in_key has # no effect on the spec. 
def new_fun(self, input_spec): - action_spec = input_spec["_action_spec"].clone() - state_spec = input_spec["_state_spec"] + action_spec = input_spec["full_action_spec"].clone() + state_spec = input_spec["full_state_spec"] if state_spec is None: state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) else: state_spec = state_spec.clone() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + in_keys_inv = self.in_keys_inv + out_keys_inv = self.out_keys_inv + for in_key, out_key in zip(in_keys_inv, out_keys_inv): if in_key != out_key: # we only change the input spec if the key is the same continue @@ -105,8 +115,8 @@ def new_fun(self, input_spec): elif in_key in state_spec.keys(True, True): state_spec[out_key] = function(self, state_spec[in_key].clone()) return CompositeSpec( - _state_spec=state_spec, - _action_spec=action_spec, + full_state_spec=state_spec, + full_action_spec=action_spec, shape=input_spec.shape, device=input_spec.device, ) @@ -143,31 +153,82 @@ class Transform(nn.Module): def __init__( self, - in_keys: Sequence[NestedKey], + in_keys: Sequence[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, out_keys_inv: Optional[Sequence[NestedKey]] = None, ): super().__init__() - if isinstance(in_keys, str): - in_keys = [in_keys] - self.in_keys = in_keys - if out_keys is None: - out_keys = copy(self.in_keys) self.out_keys = out_keys - if in_keys_inv is None: - in_keys_inv = [] self.in_keys_inv = in_keys_inv - if out_keys_inv is None: - out_keys_inv = copy(self.in_keys_inv) self.out_keys_inv = out_keys_inv self._missing_tolerance = False self.__dict__["_container"] = None self.__dict__["_parent"] = None + @property + def in_keys(self): + in_keys = self.__dict__.get("_in_keys", None) + if in_keys is None: + return [] + return in_keys + + @in_keys.setter + def in_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._in_keys = value + + @property + def out_keys(self): + out_keys = self.__dict__.get("_out_keys", None) + if out_keys is None: + return [] + return out_keys + + @out_keys.setter + def out_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._out_keys = value + + @property + def in_keys_inv(self): + in_keys_inv = self.__dict__.get("_in_keys_inv", None) + if in_keys_inv is None: + return [] + return in_keys_inv + + @in_keys_inv.setter + def in_keys_inv(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._in_keys_inv = value + + @property + def out_keys_inv(self): + out_keys_inv = self.__dict__.get("_out_keys_inv", None) + if out_keys_inv is None: + return [] + return out_keys_inv + + @out_keys_inv.setter + def out_keys_inv(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._out_keys_inv = value + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - """Resets a tranform if it is stateful.""" + """Resets a transform if it is stateful.""" return tensordict def init(self, tensordict) -> None: @@ -181,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None: """ raise NotImplementedError( - f"{self.__class__.__name__}_apply_transform is not coded. 
If the transform is coded in " + f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in " "transform._call, make sure that this method is called instead of" "transform.forward, which is reserved for usage inside nn.Modules" "or appended to a replay buffer." @@ -199,8 +260,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """ for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in tensordict.keys(include_nested=True): - observation = self._apply_transform(tensordict.get(in_key)) + value = tensordict.get(in_key, default=None) + if value is not None: + observation = self._apply_transform(value) tensordict.set( out_key, observation, @@ -215,17 +277,17 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in tensordict.keys(include_nested=True): - observation = self._apply_transform(tensordict.get(in_key)) - tensordict.set( - out_key, - observation, - ) + data = tensordict.get(in_key, None) + if data is not None: + data = self._apply_transform(data) + tensordict.set(out_key, data) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: """The parent method of a transform during the ``env.step`` execution. This method should be overwritten whenever the :meth:`~._step` needs to be @@ -237,29 +299,30 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: :meth:`~._step` will only be called by :meth:`TransformedEnv.step` and not by :meth:`TransformedEnv.reset`. + Args: + tensordict (TensorDictBase): data at time t + next_tensordict (TensorDictBase): data at time t+1 + + Returns: the data at t+1 """ - next_tensordict = tensordict.get("next") next_tensordict = self._call(next_tensordict) - tensordict.set("next", next_tensordict) - return tensordict + return next_tensordict - def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: if self.invertible: raise NotImplementedError else: - return obs + return state def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: # # We create a shallow copy of the tensordict to avoid that changes are # # exposed to the user: we'd like that the input keys remain unchanged # # in the originating script if they're being transformed. 
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key in tensordict.keys(include_nested=True): - item = self._inv_apply_transform(tensordict.get(in_key)) - tensordict.set( - out_key, - item, - ) + data = tensordict.get(in_key, None) + if data is not None: + item = self._inv_apply_transform(data) + tensordict.set(out_key, item) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") @@ -270,11 +333,15 @@ def inv(self, tensordict: TensorDictBase) -> TensorDictBase: out = self._inv_call(tensordict.clone(False)) return out + def transform_env_device(self, device: torch.device): + """Transforms the device of the parent env.""" + return device + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: """Transforms the output spec such that the resulting spec matches transform mapping. This method should generally be left untouched. Changes should be implemented using - :meth:`~.transform_observation_spec`, :meth:`~.transform_reward_spec` and :meth:`~.transform_done_spec`. + :meth:`~.transform_observation_spec`, :meth:`~.transform_reward_spec` and :meth:`~.transformfull_done_spec`. Args: output_spec (TensorSpec): spec before the transform @@ -283,16 +350,16 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: """ output_spec = output_spec.clone() - output_spec["_observation_spec"] = self.transform_observation_spec( - output_spec["_observation_spec"] + output_spec["full_observation_spec"] = self.transform_observation_spec( + output_spec["full_observation_spec"] ) - if "_reward_spec" in output_spec.keys(): - output_spec["_reward_spec"] = self.transform_reward_spec( - output_spec["_reward_spec"] + if "full_reward_spec" in output_spec.keys(): + output_spec["full_reward_spec"] = self.transform_reward_spec( + output_spec["full_reward_spec"] ) - if "_done_spec" in output_spec.keys(): - output_spec["_done_spec"] = self.transform_done_spec( - output_spec["_done_spec"] + if "full_done_spec" in output_spec.keys(): + output_spec["full_done_spec"] = self.transform_done_spec( + output_spec["full_done_spec"] ) return output_spec @@ -370,8 +437,49 @@ def clone(self): self_copy.__dict__.update(state) return self_copy + @property + def container(self): + """Returns the env containing the transform. + + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[0].container is env + True + """ + if "_container" not in self.__dict__: + raise AttributeError("transform parent uninitialized") + container = self.__dict__["_container"] + if container is None: + return container + while not isinstance(container, EnvBase): + # if it's not an env, it should be a Compose transform + if not isinstance(container, Compose): + raise ValueError( + "A transform parent must be either another Compose transform or an environment object." + ) + compose = container + container = compose.__dict__.get("_container", None) + return container + @property def parent(self) -> Optional[EnvBase]: + """Returns the parent env of the transform. + + The parent env is the env that contains all the transforms up until the current one. 
+ + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[1].parent + TransformedEnv( + env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu), + transform=Compose( + RewardSum(keys=['reward']))) + + """ if self.__dict__.get("_parent", None) is None: if "_container" not in self.__dict__: raise AttributeError("transform parent uninitialized") @@ -385,26 +493,7 @@ def parent(self) -> Optional[EnvBase]: raise ValueError( "A transform parent must be either another Compose transform or an environment object." ) - compose = container - if compose.__dict__["_container"]: - # the parent of the compose must be a TransformedEnv - compose_parent = TransformedEnv( - compose.__dict__["_container"].base_env - ) - if compose_parent.transform is not compose: - comp_parent_trans = compose_parent.transform.clone() - else: - comp_parent_trans = None - out = TransformedEnv( - compose_parent.base_env, - transform=comp_parent_trans, - ) - for orig_trans in compose.transforms: - if orig_trans is self: - break - transform = orig_trans.clone() - transform.reset_parent() - out.append_transform(transform) + out, _ = container._rebuild_up_to(self) elif isinstance(container, TransformedEnv): out = TransformedEnv(container.base_env) else: @@ -427,7 +516,15 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) -class TransformedEnv(EnvBase): +class _TEnvPostInit(_EnvPostInit): + def __call__(self, *args, **kwargs): + instance: EnvBase = super(_EnvPostInit, self).__call__(*args, **kwargs) + # we skip the materialization of the specs, because this can't be done with lazy + # transforms such as ObservationNorm. + return instance + + +class TransformedEnv(EnvBase, metaclass=_TEnvPostInit): """A transformed_in environment. 
Args: @@ -538,7 +635,11 @@ def transform(self, transform: Transform): @property def device(self) -> bool: - return self.base_env.device + device = self.base_env.device + if self.transform is None: + # during init, the device is checked + return device + return self.transform.transform_env_device(device) @device.setter def device(self, value): @@ -569,8 +670,14 @@ def _inplace_update(self): @property def output_spec(self) -> TensorSpec: """Observation spec of the transformed environment.""" - if self.__dict__.get("_output_spec", None) is None or not self.cache_specs: + if not self.cache_specs or self.__dict__.get("_output_spec", None) is None: output_spec = self.base_env.output_spec.clone() + + # remove cached key values + self.__dict__["_done_keys"] = None + self.__dict__["_reward_keys"] = None + self.__dict__["_reset_keys"] = None + output_spec.unlock_() output_spec = self.transform.transform_output_spec(output_spec) output_spec.lock_() @@ -580,11 +687,6 @@ def output_spec(self) -> TensorSpec: output_spec = self.__dict__.get("_output_spec", None) return output_spec - @property - def action_spec(self) -> TensorSpec: - """Action spec of the transformed environment.""" - return self.input_spec[("_action_spec", *self.action_key)] - @property def input_spec(self) -> TensorSpec: """Action spec of the transformed environment.""" @@ -599,40 +701,14 @@ def input_spec(self) -> TensorSpec: input_spec = self.__dict__.get("_input_spec", None) return input_spec - @property - def reward_spec(self) -> TensorSpec: - """Reward spec of the transformed environment.""" - return self.output_spec[("_reward_spec", *self.reward_key)] - - @property - def observation_spec(self) -> TensorSpec: - """Observation spec of the transformed environment.""" - observation_spec = self.output_spec["_observation_spec"] - if observation_spec is None: - observation_spec = CompositeSpec(device=self.device, shape=self.batch_size) - return observation_spec - - @property - def state_spec(self) -> TensorSpec: - """State spec of the transformed environment.""" - state_spec = self.input_spec["_state_spec"] - if state_spec is None: - state_spec = CompositeSpec(device=self.device, shape=self.batch_size) - return state_spec - - @property - def done_spec(self) -> TensorSpec: - """Done spec of the transformed environment.""" - return self.output_spec[("_done_spec", *self.done_key)] - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) tensordict_in = self.transform.inv(tensordict) - tensordict_out = self.base_env._step(tensordict_in) + next_tensordict = self.base_env._step(tensordict_in) + self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict) # we want the input entries to remain unchanged - tensordict_out = tensordict.update(tensordict_out) - tensordict_out = self.transform._step(tensordict_out) - return tensordict_out + next_tensordict = self.transform._step(tensordict, next_tensordict) + return next_tensordict def set_seed( self, seed: Optional[int] = None, static_seed: bool = False @@ -646,8 +722,19 @@ def _set_seed(self, seed: Optional[int]): def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: - tensordict = tensordict.clone(recurse=False) - out_tensordict = self.base_env.reset(tensordict=tensordict, **kwargs) + # We must avoid modifying the original tensordict so a shallow copy is necessary. + # We just select the input data and reset signal, which is all we need. 
+ tensordict = tensordict.select( + *self.reset_keys, *self.state_spec.keys(True, True), strict=False + ) + out_tensordict = self.base_env._reset(tensordict=tensordict, **kwargs) + self.base_env._complete_done(self.base_env.full_done_spec, out_tensordict) + if tensordict is not None: + # the transform may need to read previous info during reset. + # For instance, we may need to pass the step_count for partial resets. + # We update the copy of tensordict with the new data, instead of + # the contrary because newer data prevails. + out_tensordict = tensordict.update(out_tensordict) out_tensordict = self.transform.reset(out_tensordict) mt_mode = self.transform.missing_tolerance @@ -656,6 +743,12 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): self.set_missing_tolerance(mt_mode) return out_tensordict + def _complete_done( + cls, done_spec: CompositeSpec, data: TensorDictBase + ) -> TensorDictBase: + # This step has already been completed. We assume the transform module do their job correctly. + return data + def state_dict(self, *args, **kwargs) -> OrderedDict: state_dict = self.transform.state_dict(*args, **kwargs) return state_dict @@ -722,23 +815,23 @@ def insert_transform(self, index: int, transform: Transform) -> None: self._erase_metadata() def __getattr__(self, attr: str) -> Any: - if attr in self.__dir__(): + try: return super().__getattr__( attr ) # make sure that appropriate exceptions are raised - elif attr.startswith("__"): + except Exception as err: + if attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + elif "base_env" in self.__dir__(): + base_env = self.__getattr__("base_env") + return getattr(base_env, attr) raise AttributeError( - "passing built-in private methods is " - f"not permitted with type {type(self)}. " - f"Got attribute {attr}." 
- ) - elif "base_env" in self.__dir__(): - base_env = self.__getattr__("base_env") - return getattr(base_env, attr) - - raise AttributeError( - f"env not set in {self.__class__.__name__}, cannot access {attr}" - ) + f"env not set in {self.__class__.__name__}, cannot access {attr}" + ) from err def __repr__(self) -> str: env_str = indent(f"env={self.base_env}", 4 * " ") @@ -819,7 +912,7 @@ class Compose(Transform): """ def __init__(self, *transforms: Transform): - super().__init__(in_keys=[]) + super().__init__() self.transforms = nn.ModuleList(transforms) for t in transforms: t.set_container(self) @@ -841,16 +934,23 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = t(tensordict) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for t in self.transforms: - tensordict = t._step(tensordict) - return tensordict + next_tensordict = t._step(tensordict, next_tensordict) + return next_tensordict def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in reversed(self.transforms): tensordict = t._inv_call(tensordict) return tensordict + def transform_env_device(self, device: torch.device): + for t in self.transforms: + device = t.transform_env_device(device) + return device + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: for t in self.transforms[::-1]: input_spec = t.transform_input_spec(input_spec) @@ -957,6 +1057,39 @@ def set_missing_tolerance(self, mode=False): t.set_missing_tolerance(mode) super().set_missing_tolerance(mode) + def _rebuild_up_to(self, final_transform): + container = self.__dict__["_container"] + + if isinstance(container, Compose): + out, parent_compose = container._rebuild_up_to(self) + if out is None: + # returns None if there is no parent env + return None, None + elif isinstance(container, TransformedEnv): + out = TransformedEnv(container.base_env) + elif container is None: + # returns None if there is no parent env + return None, None + else: + raise ValueError(f"Container of type {type(container)} isn't supported.") + + if final_transform not in self.transforms: + raise ValueError(f"Cannot rebuild with transform {final_transform}.") + list_of_transforms = [] + for orig_trans in self.transforms: + if orig_trans is final_transform: + break + transform = orig_trans.clone() + transform.reset_parent() + list_of_transforms.append(transform) + if isinstance(container, Compose): + parent_compose.append(Compose(*list_of_transforms)) + return out, parent_compose[-1] + elif isinstance(container, TransformedEnv): + for t in list_of_transforms: + out.append_transform(t) + return out, out.transform + class ToTensorImage(ObservationTransform): """Transforms a numpy-like image (W x H x C) to a pytorch image (C x W x H). 
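The `parent` property above now defers to `Compose._rebuild_up_to` (added in this hunk) to reconstruct a partial `TransformedEnv` holding clones of every transform that precedes the queried one. As a rough illustration of the intended semantics, and assuming a working Gym backend, something along these lines should hold (the exact repr may vary between versions):

    from torchrl.envs import Compose, RewardSum, StepCounter, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(Compose(RewardSum()), StepCounter()),
    )
    # The parent of StepCounter is a rebuilt TransformedEnv over the same base env,
    # carrying clones of the transforms that precede it (here, the inner Compose
    # holding RewardSum) but not StepCounter itself.
    parent = env.transform[1].parent
    assert isinstance(parent, TransformedEnv)
    assert parent.base_env is env.base_env
    print(parent)

Because the preceding transforms are cloned and their parent pointer is reset, mutating the rebuilt parent should not affect the original environment.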
@@ -999,6 +1132,8 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) self.from_int = from_int self.unsqueeze = unsqueeze @@ -1042,11 +1177,135 @@ def _should_unsqueeze(self, observation_like: torch.FloatTensor | TensorSpec): def _pixel_observation(self, spec: TensorSpec) -> None: if isinstance(spec.space, ContinuousBox): - spec.space.maximum = self._apply_transform(spec.space.maximum) - spec.space.minimum = self._apply_transform(spec.space.minimum) + spec.space.high = self._apply_transform(spec.space.high) + spec.space.low = self._apply_transform(spec.space.low) return spec +class ClipTransform(Transform): + """A transform to clip input (state, action) or output (observation, reward) values. + + This transform can take multiple input or output keys but only one value per + transform. If multiple clipping values are needed, several transforms should + be appended one after the other. + + Args: + in_keys (list of NestedKeys): input entries (read) + out_keys (list of NestedKeys): input entries (write) + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. + out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. + + Keyword Args: + low (scalar, optional): the lower bound of the clipped space. + high (scalar, optional): the higher bound of the clipped space. + + .. note:: Providing just one of the arguments ``low`` or ``high`` is permitted, + but at least one must be provided. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> base_env = GymEnv("Pendulum-v1") + >>> env = TransformedEnv(base_env, ClipTransform(in_keys=['observation'], low=-1, high=0.1)) + >>> r = env.rollout(100) + >>> assert (r["observation"] <= 0.1).all() + """ + + def __init__( + self, + in_keys=None, + out_keys=None, + in_keys_inv=None, + out_keys_inv=None, + *, + low=None, + high=None, + ): + if in_keys is None: + in_keys = [] + if out_keys is None: + out_keys = copy(in_keys) + if in_keys_inv is None: + in_keys_inv = [] + if out_keys_inv is None: + out_keys_inv = copy(in_keys_inv) + super().__init__(in_keys, out_keys, in_keys_inv, out_keys_inv) + if low is None and high is None: + raise TypeError("Either one or both of `high` and `low` must be provided.") + + def check_val(val): + if (isinstance(val, torch.Tensor) and val.numel() > 1) or ( + isinstance(val, np.ndarray) and val.size > 1 + ): + raise TypeError( + f"low and high must be scalars or None. Got low={low} and high={high}." 
+ ) + if val is None: + return None, None, torch.finfo(torch.get_default_dtype()).max + if not isinstance(val, torch.Tensor): + val = torch.tensor(val) + if not val.dtype.is_floating_point: + val = val.float() + eps = torch.finfo(val.dtype).resolution + ext = torch.finfo(val.dtype).max + return val, eps, ext + + low, low_eps, low_min = check_val(low) + high, high_eps, high_max = check_val(high) + if low is not None and high is not None and low >= high: + raise ValueError("`low` must be stricly lower than `high`.") + self.register_buffer("low", low) + self.low_eps = low_eps + self.low_min = -low_min + self.register_buffer("high", high) + self.high_eps = high_eps + self.high_max = high_max + + def _apply_transform(self, obs: torch.Tensor) -> None: + if self.low is None: + return obs.clamp_max(self.high) + elif self.high is None: + return obs.clamp_min(self.low) + return obs.clamp(self.low, self.high) + + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: + if self.low is None: + return state.clamp_max(self.high) + elif self.high is None: + return state.clamp_min(self.low) + return state.clamp(self.low, self.high) + + @_apply_to_composite + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + return BoundedTensorSpec( + shape=observation_spec.shape, + device=observation_spec.device, + dtype=observation_spec.dtype, + high=self.high + self.high_eps if self.high is not None else self.high_max, + low=self.low - self.low_eps if self.low is not None else self.low_min, + ) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + for key in self.in_keys: + if key in self.parent.reward_keys: + spec = self.parent.output_spec["full_reward_spec"][key] + self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec( + shape=spec.shape, + device=spec.device, + dtype=spec.dtype, + high=self.high + self.high_eps + if self.high is not None + else self.high_max, + low=self.low - self.low_eps + if self.low is not None + else self.low_min, + ) + return self.parent.output_spec["full_reward_spec"] + + # No need to transform the input spec since the outside world won't see the difference + # def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + # ... + + class TargetReturn(Transform): """Sets a target return for the agent to achieve in the environment. 
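The `ClipTransform` added above accepts a single `low`/`high` pair per instance, so different bounds for different entries are obtained by stacking several instances in a `Compose`. Below is a minimal usage sketch, assuming the class is exported alongside the other transforms and a Gym backend is installed:

    from torchrl.envs import Compose, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import ClipTransform

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            # clip observations to [-1, 1]
            ClipTransform(in_keys=["observation"], low=-1.0, high=1.0),
            # clip rewards from above only: providing a single bound is allowed
            ClipTransform(in_keys=["reward"], high=0.0),
        ),
    )
    rollout = env.rollout(10)
    assert (rollout["observation"].abs() <= 1.0).all()
    assert (rollout["next", "reward"] <= 0.0).all()

Note that the transformed observation and reward specs are widened by a small epsilon (the dtype's resolution, via `low_eps`/`high_eps`), so spec membership checks do not fail at the clipping boundary.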
@@ -1132,17 +1391,17 @@ def __init__( self.mode = mode def reset(self, tensordict: TensorDict): - init_target_return = torch.full( - size=(*tensordict.batch_size, 1), - fill_value=self.target_return, - dtype=torch.float32, - device=tensordict.device, - ) for out_key in self.out_keys: target_return = tensordict.get(out_key, default=None) if target_return is None: + init_target_return = torch.full( + size=(*tensordict.batch_size, 1), + fill_value=self.target_return, + dtype=torch.float32, + device=tensordict.device, + ) target_return = init_target_return tensordict.set( @@ -1162,18 +1421,25 @@ def _call(self, tensordict: TensorDict) -> TensorDict: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for out_key in self.out_keys: - tensordict.set(("next", out_key), tensordict.get(out_key)) - return super()._step(tensordict) + next_tensordict.set(out_key, tensordict.get(out_key)) + return super()._step(tensordict, next_tensordict) def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: + if target_return.shape != reward.shape: + raise ValueError( + f"The shape of the reward ({reward.shape}) and target return ({target_return.shape}) must match." + ) if self.mode == "reduce": target_return = target_return - reward return target_return elif self.mode == "constant": + target_return = target_return return target_return else: raise ValueError("Unknown mode: {}".format(self.mode)) @@ -1190,15 +1456,15 @@ def transform_observation_spec( raise ValueError( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." 
) - - target_return_spec = BoundedTensorSpec( - minimum=-float("inf"), - maximum=self.target_return, - shape=self.parent.reward_spec.shape, - dtype=self.parent.reward_spec.dtype, - device=self.parent.reward_spec.device, - ) - observation_spec[self.out_keys[0]] = target_return_spec + for key in self.out_keys: + target_return_spec = BoundedTensorSpec( + low=-float("inf"), + high=self.target_return, + shape=self.parent.reward_spec.shape, + dtype=self.parent.reward_spec.dtype, + device=self.parent.reward_spec.device, + ) + observation_spec[key] = target_return_spec return observation_spec @@ -1221,6 +1487,8 @@ def __init__( ): if in_keys is None: in_keys = ["reward"] + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) clamp_min_tensor = ( clamp_min if isinstance(clamp_min, Tensor) else torch.tensor(clamp_min) @@ -1275,6 +1543,8 @@ def __init__( ): if in_keys is None: in_keys = ["reward"] + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: @@ -1316,6 +1586,8 @@ def __init__( ) if in_keys is None: in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) self.w = int(w) self.h = int(h) @@ -1344,9 +1616,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) - observation_spec.shape = space.minimum.shape + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape else: observation_spec.shape = self._apply_transform( torch.zeros(observation_spec.shape) @@ -1384,6 +1656,8 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) self.w = w self.h = h if h else w @@ -1396,9 +1670,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) - observation_spec.shape = space.minimum.shape + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape else: observation_spec.shape = self._apply_transform( torch.zeros(observation_spec.shape) @@ -1438,6 +1712,8 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) if not allow_positive_dim and first_dim >= 0: raise ValueError( @@ -1475,9 +1751,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) - observation_spec.shape = space.minimum.shape + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape 
else: observation_spec.shape = self._apply_transform( torch.zeros(observation_spec.shape) @@ -1524,6 +1800,12 @@ def __init__( ): if in_keys is None: in_keys = [] # default + if out_keys is None: + out_keys = copy(in_keys) + if in_keys_inv is None: + in_keys_inv = [] # default + if out_keys_inv is None: + out_keys_inv = copy(in_keys_inv) super().__init__( in_keys=in_keys, out_keys=out_keys, @@ -1556,9 +1838,9 @@ def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def _transform_spec(self, spec: TensorSpec): space = spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) - spec.shape = space.minimum.shape + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + spec.shape = space.low.shape else: spec.shape = self._apply_transform(torch.zeros(spec.shape)).shape return spec @@ -1566,9 +1848,9 @@ def _transform_spec(self, spec: TensorSpec): def _inv_transform_spec(self, spec: TensorSpec) -> None: space = spec.space if isinstance(space, ContinuousBox): - space.minimum = self._inv_apply_transform(space.minimum) - space.maximum = self._inv_apply_transform(space.maximum) - spec.shape = space.minimum.shape + space.low = self._inv_apply_transform(space.low) + space.high = self._inv_apply_transform(space.high) + spec.shape = space.low.shape else: spec.shape = self._inv_apply_transform(torch.zeros(spec.shape)).shape return spec @@ -1633,6 +1915,162 @@ def squeeze_dim(self): _inv_apply_transform = UnsqueezeTransform._apply_transform +class PermuteTransform(Transform): + """Permutation transform. + + Permutes input tensors along the desired dimensions. The permutations + must be provided along the feature dimension (not batch dimension). + + Args: + dims (list of int): the permuted order of the dimensions. Must be a reordering + of the dims ``[-(len(dims)), ..., -1]``. + in_keys (list of NestedKeys): input entries (read). + out_keys (list of NestedKeys): input entries (write). Defaults to ``in_keys`` if + not provided. + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. + out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. Defaults to ``in_keys_in`` if + not provided. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> base_env = GymEnv("ALE/Pong-v5") + >>> base_env.rollout(2) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False) + >>> env = TransformedEnv(base_env, PermuteTransform((-1, -3, -2), in_keys=["pixels"])) + >>> env.rollout(2) # channels are at the end + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False) + + """ + + def __init__( + self, + dims, + in_keys=None, + out_keys=None, + in_keys_inv=None, + out_keys_inv=None, + ): + if in_keys is None: + in_keys = [] + if out_keys is None: + out_keys = copy(in_keys) + if in_keys_inv is None: + in_keys_inv = [] + if out_keys_inv is None: + out_keys_inv = copy(in_keys_inv) + + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ) + # check dims + self.dims = dims + if sorted(dims) != list(range(-len(dims), 0)): + raise ValueError( + f"Only tailing dims with negative indices are supported by {self.__class__.__name__}. Got {dims} instead." 
+ ) + + @staticmethod + def _invert_permute(p): + def _find_inv(i): + for j, _p in enumerate(p): + if _p < 0: + inv = True + _p = len(p) + _p + else: + inv = False + if i == _p: + if inv: + return j - len(p) + else: + return j + else: + # unreachable + raise RuntimeError + + return [_find_inv(i) for i in range(len(p))] + + def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: + observation = observation.permute( + *list(range(observation.ndimension() - len(self.dims))), *self.dims + ) + return observation + + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: + permuted_dims = self._invert_permute(self.dims) + state = state.permute( + *list(range(state.ndimension() - len(self.dims))), *permuted_dims + ) + return state + + @_apply_to_composite + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + observation_spec = self._edit_space(observation_spec) + observation_spec.shape = torch.Size( + [ + *observation_spec.shape[: -len(self.dims)], + *[observation_spec.shape[dim] for dim in self.dims], + ] + ) + return observation_spec + + @_apply_to_composite_inv + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + permuted_dims = self._invert_permute(self.dims) + input_spec = self._edit_space_inv(input_spec) + input_spec.shape = torch.Size( + [ + *input_spec.shape[: -len(permuted_dims)], + *[input_spec.shape[dim] for dim in permuted_dims], + ] + ) + return input_spec + + def _edit_space(self, spec: TensorSpec) -> None: + if isinstance(spec.space, ContinuousBox): + spec.space.high = self._apply_transform(spec.space.high) + spec.space.low = self._apply_transform(spec.space.low) + return spec + + def _edit_space_inv(self, spec: TensorSpec) -> None: + if isinstance(spec.space, ContinuousBox): + spec.space.high = self._inv_apply_transform(spec.space.high) + spec.space.low = self._inv_apply_transform(spec.space.low) + return spec + + class GrayScale(ObservationTransform): """Turns a pixel observation to grayscale.""" @@ -1643,7 +2081,9 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS - super(GrayScale, self).__init__(in_keys=in_keys, out_keys=out_keys) + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = F.rgb_to_grayscale(observation) @@ -1653,9 +2093,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) - observation_spec.shape = space.minimum.shape + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape else: observation_spec.shape = self._apply_transform( torch.zeros(observation_spec.shape) @@ -1674,15 +2114,15 @@ class ObservationNorm(ObservationTransform): Args: loc (number or tensor): location of the affine transform scale (number or tensor): scale of the affine transform - in_keys (seuqence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"]. + in_keys (sequence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"]. All entries will be normalized with the same values: if a different behaviour is desired (e.g. 
a different normalization for pixels and states) different :obj:`ObservationNorm` objects should be used. - out_keys (seuqence of NestedKey, optional): output entries. Defaults to the value of `in_keys`. - in_keys_inv (seuqence of NestedKey, optional): ObservationNorm also supports inverse transforms. This will + out_keys (sequence of NestedKey, optional): output entries. Defaults to the value of `in_keys`. + in_keys_inv (sequence of NestedKey, optional): ObservationNorm also supports inverse transforms. This will only occur if a list of keys is provided to :obj:`in_keys_inv`. If none is provided, only the forward transform will be called. - out_keys_inv (seuqence of NestedKey, optional): output entries for the inverse transform. + out_keys_inv (sequence of NestedKey, optional): output entries for the inverse transform. Defaults to the value of `in_keys_inv`. standard_normal (bool, optional): if ``True``, the transform will be @@ -1734,10 +2174,23 @@ def __init__( standard_normal: bool = False, ): if in_keys is None: + warnings.warn( + "Not passing in_keys to ObservationNorm will soon be deprecated. " + "Ensure you specify the entries to be normalized", + category=DeprecationWarning, + ) in_keys = [ "observation", "pixels", ] + + if out_keys is None: + out_keys = copy(in_keys) + if in_keys_inv is None: + in_keys_inv = [] + if out_keys_inv is None: + out_keys_inv = copy(in_keys_inv) + super().__init__( in_keys=in_keys, out_keys=out_keys, @@ -1895,7 +2348,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: loc = self.loc return obs * scale + loc - def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: if self.loc is None or self.scale is None: raise RuntimeError( "Loc/Scale have not been initialized. Either pass in values in the constructor " @@ -1904,26 +2357,26 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: if not self.standard_normal: loc = self.loc scale = self.scale - return (obs - loc) / scale + return (state - loc) / scale else: scale = self.scale loc = self.loc - return obs * scale + loc + return state * scale + loc @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) return observation_spec @_apply_to_composite_inv def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: space = input_spec.space if isinstance(space, ContinuousBox): - space.minimum = self._apply_transform(space.minimum) - space.maximum = self._apply_transform(space.maximum) + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) return input_spec def __repr__(self) -> str: @@ -1954,9 +2407,9 @@ class CatFrames(ObservationTransform): dim (int): dimension along which concatenate the observations. Should be negative, to ensure that it is compatible with environments of different batch_size. - in_keys (seuqence of NestedKey, optional): keys pointing to the frames that have + in_keys (sequence of NestedKey, optional): keys pointing to the frames that have to be concatenated. Defaults to ["pixels"]. 
- out_keys (seuqence of NestedKey, optional): keys pointing to where the output + out_keys (sequence of NestedKey, optional): keys pointing to where the output has to be written. Defaults to the value of `in_keys`. padding (str, optional): the padding method. One of ``"same"`` or ``"zeros"``. Defaults to ``"same"``, ie. the first value is uesd for padding. @@ -2017,11 +2470,16 @@ class CatFrames(ObservationTransform): >>> # let's check that our sample is the same as the batch collected during inference >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all() + .. note:: :class:`~CatFrames` currently only supports ``"done"`` + signal at the root. Nested ``done``, + such as those found in MARL settings, are currently not supported. + If this feature is needed, please raise an issue on TorchRL repo. + """ inplace = False _CAT_DIM_ERR = ( - "dim must be > 0 to accomodate for tensordict of " + "dim must be < 0 to accomodate for tensordict of " "different batch-sizes (since negative dims are batch invariant)." ) ACCEPTED_PADDING = {"same", "zeros"} @@ -2037,9 +2495,11 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) self.N = N - if dim > 0: + if dim >= 0: raise ValueError(self._CAT_DIM_ERR) self.dim = dim if padding not in self.ACCEPTED_PADDING: @@ -2116,7 +2576,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: for in_key, out_key in zip(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = f"_cat_buffers_{in_key}" - data = tensordict[in_key] + data = tensordict.get(in_key) d = data.size(self.dim) buffer = getattr(self, buffer_name) if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): @@ -2153,9 +2613,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: space = observation_spec.space if isinstance(space, ContinuousBox): - space.minimum = torch.cat([space.minimum] * self.N, self.dim) - space.maximum = torch.cat([space.maximum] * self.N, self.dim) - observation_spec.shape = space.minimum.shape + space.low = torch.cat([space.low] * self.N, self.dim) + space.high = torch.cat([space.high] * self.N, self.dim) + observation_spec.shape = space.low.shape else: shape = list(observation_spec.shape) shape[self.dim] = self.N * shape[self.dim] @@ -2212,7 +2672,6 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # If so, we must add an offset data = tensordict.get(in_key) if isinstance(in_key, tuple) and in_key[0] == "next": - # let's get the out_key we have already processed prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]] prev_val = tensordict.get(prev_out_key) @@ -2287,11 +2746,15 @@ def __init__( loc: Union[float, torch.Tensor], scale: Union[float, torch.Tensor], in_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Optional[Sequence[NestedKey]] = None, standard_normal: bool = False, ): if in_keys is None: in_keys = ["reward"] - super().__init__(in_keys=in_keys) + if out_keys is None: + out_keys = copy(in_keys) + + super().__init__(in_keys=in_keys, out_keys=out_keys) if not isinstance(standard_normal, torch.Tensor): standard_normal = torch.tensor(standard_normal) self.register_buffer("standard_normal", standard_normal) @@ -2348,22 +2811,112 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: forward = _call -class DoubleToFloat(Transform): - """Maps actions float to double 
before they are called on the environment. +class DTypeCastTransform(Transform): + """Casts one dtype to another for selected keys. + + Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided + during construction, the class behaviour will change: + + * If the keys are provided, those entries and those entries only will be + transformed from ``dtype_in`` to ``dtype_out`` entries; + * If the keys are not provided and the object is within an environment + register of transforms, the input and output specs that have a dtype + set to ``dtype_in`` will be used as in_keys_inv / in_keys respectively. + * If the keys are not provided and the object is used without an + environment, the ``forward`` / ``inverse`` pass will scan through the + input tensordict for all ``dtype_in`` values and map them to a ``dtype_out`` + tensor. For large data structures, this can impact performance as this + scanning doesn't come for free. The keys to be + transformed will not be cached. + Note that, in this case, the out_keys (resp. + out_keys_inv) cannot be passed as the order on which the keys are processed + cannot be anticipated precisely. Args: - in_keys (sequence of NestedKey, optional): list of double keys to be converted to - float before being exposed to external objects and functions. - in_keys_inv (sequence of NestedKey, optional): list of float keys to be converted to - double before being passed to the contained base_env or storage. + dtype_in (torch.dtype): the input dtype (from the env). + dtype_out (torch.dtype): the output dtype (for model training). + in_keys (sequence of NestedKey, optional): list of ``dtype_in`` keys to be converted to + ``dtype_out`` before being exposed to external objects and functions. + out_keys (sequence of NestedKey, optional): list of destination keys. + Defaults to ``in_keys`` if not provided. + in_keys_inv (sequence of NestedKey, optional): list of ``dtype_out`` keys to be converted to + ``dtype_in`` before being passed to the contained base_env or storage. + out_keys_inv (sequence of NestedKey, optional): list of destination keys for inverse + transform. + Defaults to ``in_keys_inv`` if not provided. Examples: >>> td = TensorDict( - ... {'obs': torch.ones(1, dtype=torch.double)}, []) - >>> transform = DoubleToFloat(in_keys=["obs"]) + ... {'obs': torch.ones(1, dtype=torch.double), + ... 'not_transformed': torch.ones(1, dtype=torch.double), + ... }, []) + >>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"]) + >>> _ = transform(td) + >>> print(td.get("obs").dtype) + torch.float32 + >>> print(td.get("not_transformed").dtype) + torch.float64 + + In "automatic" mode, all float64 entries are transformed: + + Examples: + >>> td = TensorDict( + ... {'obs': torch.ones(1, dtype=torch.double), + ... 'not_transformed': torch.ones(1, dtype=torch.double), + ... }, []) + >>> transform = DTypeCastTransform(torch.double, torch.float) >>> _ = transform(td) >>> print(td.get("obs").dtype) torch.float32 + >>> print(td.get("not_transformed").dtype) + torch.float32 + + The same behaviour is the rule when environments are constructedw without + specifying the transform keys: + + Examples: + >>> class MyEnv(EnvBase): + ... def __init__(self): + ... super().__init__() + ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) + ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) + ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) + ... 
self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) + ... def _reset(self, data=None): + ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) + ... def _step(self, data): + ... assert data["action"].dtype == torch.float64 + ... reward = self.reward_spec.rand() + ... done = torch.zeros((1,), dtype=torch.bool) + ... obs = self.observation_spec.rand() + ... assert reward.dtype == torch.float64 + ... assert obs["obs"].dtype == torch.float64 + ... return obs.select().set("next", obs.update({"reward": reward, "done": done})) + ... def _set_seed(self, seed): + ... pass + >>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float)) + >>> assert env.action_spec.dtype == torch.float32 + >>> assert env.observation_spec["obs"].dtype == torch.float32 + >>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype + >>> print(env.rollout(2)) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False) + >>> assert env.transform.in_keys == ["obs", "reward"] + >>> assert env.transform.in_keys_inv == ["action"] """ @@ -2371,59 +2924,239 @@ class DoubleToFloat(Transform): def __init__( self, + dtype_in: torch.dtype, + dtype_out: torch.dtype, in_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, + out_keys_inv: Optional[Sequence[NestedKey]] = None, ): - super().__init__(in_keys=in_keys, in_keys_inv=in_keys_inv) + self.dtype_in = dtype_in + self.dtype_out = dtype_out + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ) + + @property + def in_keys(self): + in_keys = self.__dict__.get("_in_keys", None) + if in_keys is None: + parent = self.parent + if parent is None: + # in_keys=None means all entries of dtype_in will be mapped to dtype_out + return None + in_keys = [] + for key, spec in parent.observation_spec.items(True, True): + if spec.dtype == self.dtype_in: + in_keys.append(unravel_key(key)) + for key, spec in parent.full_reward_spec.items(True, True): + if spec.dtype == self.dtype_in: + in_keys.append(unravel_key(key)) + self._in_keys = in_keys + if self.__dict__.get("_out_keys", None) is None: + self.out_keys = copy(in_keys) + return in_keys + + @in_keys.setter + def in_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._in_keys = value + + @property + def out_keys(self): + out_keys = self.__dict__.get("_out_keys", None) + if out_keys is None: + out_keys = self._out_keys = copy(self.in_keys) + return out_keys + + @out_keys.setter + def out_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + 
self._out_keys = value + + @property + def in_keys_inv(self): + in_keys_inv = self.__dict__.get("_in_keys_inv", None) + if in_keys_inv is None: + parent = self.parent + if parent is None: + # in_keys_inv=None means all entries of dtype_out will be mapped to dtype_in + return None + in_keys_inv = [] + for key, spec in parent.full_action_spec.items(True, True): + if spec.dtype == self.dtype_in: + in_keys_inv.append(unravel_key(key)) + for key, spec in parent.full_state_spec.items(True, True): + if spec.dtype == self.dtype_in: + in_keys_inv.append(unravel_key(key)) + self._in_keys_inv = in_keys_inv + if self.__dict__.get("_out_keys_inv", None) is None: + self.out_keys_inv = copy(in_keys_inv) + return in_keys_inv + + @in_keys_inv.setter + def in_keys_inv(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._in_keys_inv = value + + @property + def out_keys_inv(self): + out_keys_inv = self.__dict__.get("_out_keys_inv", None) + if out_keys_inv is None: + out_keys_inv = self._out_keys_inv = copy(self.in_keys_inv) + return out_keys_inv + + @out_keys_inv.setter + def out_keys_inv(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._out_keys_inv = value + + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Reads the input tensordict, and for the selected keys, applies the transform.""" + in_keys = self.in_keys + out_keys = self.out_keys + if in_keys is None: + if out_keys is not None: + raise ValueError( + "in_keys wasn't provided and couldn't be retrieved. However, " + "out_keys was passed to the constructor. Since the order of the " + "entries mapped from dtype_in to dtype_out cannot be guaranteed, " + "this functionality is not covered. Consider passing the in_keys " + "or not passing any out_keys." + ) + for in_key, item in list(tensordict.items(True, True)): + if item.dtype == self.dtype_in: + item = self._apply_transform(item) + tensordict.set(in_key, item) + else: + # we made sure that if in_keys is not None, out_keys is not None either + for in_key, out_key in zip(in_keys, out_keys): + item = self._apply_transform(tensordict.get(in_key)) + tensordict.set(out_key, item) + return tensordict + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + in_keys_inv = self.in_keys_inv + out_keys_inv = self.out_keys_inv + if in_keys_inv is None: + if out_keys_inv is not None: + raise ValueError( + "in_keys_inv wasn't provided and couldn't be retrieved. However, " + "out_keys_inv was passed to the constructor. Since the order of the " + "entries mapped from dtype_in to dtype_out cannot be guaranteed, " + "this functionality is not covered. Consider passing the in_keys_inv " + "or not passing any out_keys_inv." 
+ ) + for in_key_inv, item in list(tensordict.items(True, True)): + if item.dtype == self.dtype_out: + item = self._inv_apply_transform(item) + tensordict.set(in_key_inv, item) + return tensordict + else: + return super()._inv_call(tensordict) def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: - return obs.to(torch.float) + return obs.to(self.dtype_out) - def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: - return obs.to(torch.double) + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: + return state.to(self.dtype_in) def _transform_spec(self, spec: TensorSpec) -> None: if isinstance(spec, CompositeSpec): for key in spec: self._transform_spec(spec[key]) else: - spec.dtype = torch.float + spec = spec.clone() + spec.dtype = self.dtype_out space = spec.space if isinstance(space, ContinuousBox): - space.minimum = space.minimum.to(torch.float) - space.maximum = space.maximum.to(torch.float) + space.low = space.low.to(self.dtype_out) + space.high = space.high.to(self.dtype_out) + return spec def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - action_spec = input_spec["_action_spec"] - state_spec = input_spec["_state_spec"] - for key in self.in_keys_inv: - if key in action_spec.keys(True): - _spec = action_spec - elif state_spec is not None and key in state_spec.keys(True): - _spec = state_spec + full_action_spec = input_spec["full_action_spec"] + full_state_spec = input_spec["full_state_spec"] + # if this method is called, then it must have a parent and in_keys_inv will be defined + if self.in_keys_inv is None: + raise NotImplementedError( + f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}." + ) + for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv): + if in_key_inv in full_action_spec.keys(True): + _spec = full_action_spec[in_key_inv] + target = "action" + elif in_key_inv in full_state_spec.keys(True): + _spec = full_state_spec[in_key_inv] + target = "state" else: - raise KeyError(f"Key {key} not found in state_spec and action_spec.") - if _spec[key].dtype is not torch.double: + raise KeyError( + f"Key {in_key_inv} not found in state_spec and action_spec." + ) + if _spec.dtype != self.dtype_in: raise TypeError( - f"input_spec[{key}].dtype is not double: {input_spec[key].dtype}" + f"input_spec[{in_key_inv}].dtype is not {self.dtype_in}: {in_key_inv.dtype}" ) - self._transform_spec(_spec[key]) + _spec = self._transform_spec(_spec) + if target == "action": + full_action_spec[out_key_inv] = _spec + elif target == "state": + full_state_spec[out_key_inv] = _spec + else: + # unreachable + raise RuntimeError return input_spec - @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - reward_key = self.parent.reward_key if self.parent is not None else "reward" - if unravel_key(reward_key) in self.in_keys: - if reward_spec.dtype is not torch.double: - raise TypeError("reward_spec.dtype is not double") - - self._transform_spec(reward_spec) - return reward_spec + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + if self.in_keys is None: + raise NotImplementedError( + f"Calling transform_reward_spec without a parent environment isn't supported yet for {type(self)}." 
+ ) + full_reward_spec = output_spec["full_reward_spec"] + for reward_key, reward_spec in list(full_reward_spec.items(True, True)): + # find out_key that match the in_key + for in_key, out_key in zip(self.in_keys, self.out_keys): + if reward_key == in_key: + if reward_spec.dtype != self.dtype_in: + raise TypeError(f"reward_spec.dtype is not {self.dtype_in}") + full_reward_spec[out_key] = self._transform_spec(reward_spec) + output_spec["full_observation_spec"] = self.transform_observation_spec( + output_spec["full_observation_spec"] + ) + return output_spec - @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - self._transform_spec(observation_spec) - return observation_spec + def transform_observation_spec(self, observation_spec): + full_observation_spec = observation_spec + for observation_key, observation_spec in list( + full_observation_spec.items(True, True) + ): + # find out_key that match the in_key + for in_key, out_key in zip(self.in_keys, self.out_keys): + if observation_key == in_key: + if observation_spec.dtype != self.dtype_in: + raise TypeError( + f"observation_spec.dtype is not {self.dtype_in}" + ) + full_observation_spec[out_key] = self._transform_spec( + observation_spec + ) + return full_observation_spec def __repr__(self) -> str: s = ( @@ -2433,6 +3166,216 @@ def __repr__(self) -> str: return s +class DoubleToFloat(DTypeCastTransform): + """Casts one dtype to another for selected keys. + + Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided + during construction, the class behaviour will change: + + * If the keys are provided, those entries and those entries only will be + transformed from ``float64`` to ``float32`` entries; + * If the keys are not provided and the object is within an environment + register of transforms, the input and output specs that have a dtype + set to ``float64`` will be used as in_keys_inv / in_keys respectively. + * If the keys are not provided and the object is used without an + environment, the ``forward`` / ``inverse`` pass will scan through the + input tensordict for all float64 values and map them to a float32 + tensor. For large data structures, this can impact performance as this + scanning doesn't come for free. The keys to be + transformed will not be cached. + Note that, in this case, the out_keys (resp. + out_keys_inv) cannot be passed as the order on which the keys are processed + cannot be anticipated precisely. + + Args: + in_keys (sequence of NestedKey, optional): list of double keys to be converted to + float before being exposed to external objects and functions. + out_keys (sequence of NestedKey, optional): list of destination keys. + Defaults to ``in_keys`` if not provided. + in_keys_inv (sequence of NestedKey, optional): list of float keys to be converted to + double before being passed to the contained base_env or storage. + out_keys_inv (sequence of NestedKey, optional): list of destination keys for inverse + transform. + Defaults to ``in_keys_inv`` if not provided. + + Examples: + >>> td = TensorDict( + ... {'obs': torch.ones(1, dtype=torch.double), + ... 'not_transformed': torch.ones(1, dtype=torch.double), + ... }, []) + >>> transform = DoubleToFloat(in_keys=["obs"]) + >>> _ = transform(td) + >>> print(td.get("obs").dtype) + torch.float32 + >>> print(td.get("not_transformed").dtype) + torch.float64 + + In "automatic" mode, all float64 entries are transformed: + + Examples: + >>> td = TensorDict( + ... {'obs': torch.ones(1, dtype=torch.double), + ... 
'not_transformed': torch.ones(1, dtype=torch.double), + ... }, []) + >>> transform = DoubleToFloat() + >>> _ = transform(td) + >>> print(td.get("obs").dtype) + torch.float32 + >>> print(td.get("not_transformed").dtype) + torch.float32 + + The same behaviour is the rule when environments are constructedw without + specifying the transform keys: + + Examples: + >>> class MyEnv(EnvBase): + ... def __init__(self): + ... super().__init__() + ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) + ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) + ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) + ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) + ... def _reset(self, data=None): + ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) + ... def _step(self, data): + ... assert data["action"].dtype == torch.float64 + ... reward = self.reward_spec.rand() + ... done = torch.zeros((1,), dtype=torch.bool) + ... obs = self.observation_spec.rand() + ... assert reward.dtype == torch.float64 + ... assert obs["obs"].dtype == torch.float64 + ... return obs.select().set("next", obs.update({"reward": reward, "done": done})) + ... def _set_seed(self, seed): + ... pass + >>> env = TransformedEnv(MyEnv(), DoubleToFloat()) + >>> assert env.action_spec.dtype == torch.float32 + >>> assert env.observation_spec["obs"].dtype == torch.float32 + >>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype + >>> print(env.rollout(2)) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False) + >>> assert env.transform.in_keys == ["obs", "reward"] + >>> assert env.transform.in_keys_inv == ["action"] + + """ + + invertible = True + + def __init__( + self, + in_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Optional[Sequence[NestedKey]] = None, + in_keys_inv: Optional[Sequence[NestedKey]] = None, + out_keys_inv: Optional[Sequence[NestedKey]] = None, + ): + super().__init__( + dtype_in=torch.double, + dtype_out=torch.float, + in_keys=in_keys, + in_keys_inv=in_keys_inv, + out_keys=out_keys, + out_keys_inv=out_keys_inv, + ) + + +class DeviceCastTransform(Transform): + """Moves data from one device to another. + + Args: + device (torch.device or equivalent): the destination device. + orig_device (torch.device or equivalent): the origin device. If not specified and + a parent environment exists, it it retrieved from it. In all other cases, + it remains unspecified. + + Examples: + >>> td = TensorDict( + ... {'obs': torch.ones(1, dtype=torch.double), + ... 
}, [], device="cpu:0") + >>> transform = DeviceCastTransform(device=torch.device("cpu:2")) + >>> td = transform(td) + >>> print(td.device) + cpu:2 + + """ + + invertible = True + + def __init__( + self, + device, + orig_device=None, + ): + self.device = torch.device(device) + self.orig_device = ( + torch.device(orig_device) if orig_device is not None else orig_device + ) + super().__init__() + + def set_container(self, container: Union[Transform, EnvBase]) -> None: + if self.orig_device is None: + if isinstance(container, EnvBase): + device = container.device + else: + parent = container.parent + if parent is not None: + device = parent.device + else: + device = torch.device("cpu") + self.orig_device = device + return super().set_container(container) + + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + return tensordict.to(self.device, non_blocking=True) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + return tensordict.to(self.device, non_blocking=True) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + parent = self.parent + if parent is None: + if self.orig_device is None: + return tensordict + return tensordict.to(self.orig_device, non_blocking=True) + return tensordict.to(parent.device, non_blocking=True) + + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + return input_spec.to(self.device) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + return reward_spec.to(self.device) + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + return observation_spec.to(self.device) + + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + return output_spec.to(self.device) + + def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: + return done_spec.to(self.device) + + def transform_env_device(self, device): + return self.device + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(device={self.device}, orig_device={self.orig_device})" + return s + + class CatTensors(Transform): """Concatenates several keys in a single tensor. @@ -2490,7 +3433,6 @@ def __init__( in_keys = sorted(in_keys, key=_sort_keys) if not isinstance(out_key, (str, tuple)): raise Exception("CatTensors requires out_key to be of type NestedKey") - # super().__init__(in_keys=in_keys) super(CatTensors, self).__init__(in_keys=in_keys, out_keys=[out_key]) self.dim = dim self._del_keys = del_keys @@ -2652,11 +3594,13 @@ def __init__( in_keys = in_keys_inv else: in_keys = [] + if in_keys_inv is None: + in_keys_inv = [] super().__init__( in_keys=in_keys, - out_keys=in_keys, + out_keys=copy(in_keys), in_keys_inv=in_keys_inv, - out_keys_inv=in_keys_inv, + out_keys_inv=copy(in_keys_inv), ) self.num_actions_effective = num_actions_effective self.max_actions = max_actions @@ -2682,7 +3626,7 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: raise RuntimeError( f"action.shape[-1]={action.shape[-1]} must match self.max_actions={self.max_actions}." 
) - action = action.argmax(-1) # bool to int + action = action.long().argmax(-1) # bool to int idx = action >= self.num_actions_effective if idx.any(): action[idx] = torch.randint(self.num_actions_effective, (idx.sum(),)) @@ -2691,8 +3635,8 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: def transform_input_spec(self, input_spec: CompositeSpec): input_spec = input_spec.clone() - for key in input_spec["_action_spec"].keys(True, True): - key = ("_action_spec", key) + for key in input_spec["full_action_spec"].keys(True, True): + key = ("full_action_spec", key) break else: raise KeyError("key not found in action_spec.") @@ -2724,21 +3668,23 @@ class FrameSkipTransform(Transform): """ def __init__(self, frame_skip: int = 1): - super().__init__([]) + super().__init__() if frame_skip < 1: raise ValueError("frame_skip should have a value greater or equal to one.") self.frame_skip = frame_skip - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: parent = self.parent if parent is None: raise RuntimeError("parent not found for FrameSkipTransform") reward_key = parent.reward_key - reward = tensordict.get(("next", reward_key)) + reward = next_tensordict.get(reward_key) for _ in range(self.frame_skip - 1): - tensordict = parent._step(tensordict) - reward = reward + tensordict.get(("next", reward_key)) - return tensordict.set(("next", reward_key), reward) + next_tensordict = parent._step(tensordict) + reward = reward + next_tensordict.get(reward_key) + return next_tensordict.set(reward_key, reward) def forward(self, tensordict): raise RuntimeError( @@ -2753,8 +3699,11 @@ class NoopResetEnv(Transform): env (EnvBase): env on which the random actions have to be performed. Can be the same env as the one provided to the TransformedEnv class - noops (int, optional): number of actions performed after reset. - Default is `30`. + noops (int, optional): upper-bound on the number of actions + performed after reset. Default is `30`. + If noops is too high such that it results in the env being + done or truncated before the all the noops are applied, + in multiple trials, the transform raises a RuntimeError. random (bool, optional): if False, the number of random ops will always be equal to the noops value. If True, the number of random actions will be randomly selected between 0 and noops. @@ -2763,11 +3712,8 @@ class NoopResetEnv(Transform): """ def __init__(self, noops: int = 30, random: bool = True): - """Sample initial states by taking random number of no-ops on reset. - - No-op is assumed to be action 0. - """ - super().__init__([]) + """Sample initial states by taking random number of no-ops on reset.""" + super().__init__() self.noops = noops self.random = random @@ -2785,7 +3731,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: raise RuntimeError( "NoopResetEnv.parent not found. Make sure that the parent is set." 
) - done_key = parent.done_key + done_keys = parent.done_keys reward_key = parent.reward_key if parent.batch_size.numel() > 1: raise ValueError( @@ -2801,31 +3747,40 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) - trial = 0 - while True: + trial = 0 + while trial <= _MAX_NOOPS_TRIALS: i = 0 + while i < noops: i += 1 tensordict = parent.rand_step(tensordict) tensordict = step_mdp(tensordict, exclude_done=False) - if tensordict.get(done_key): + reset = False + # if any of the done_keys is True, we break + for done_key in done_keys: + done = tensordict.get(done_key) + if done.numel() > 1: + raise ValueError( + f"{type(self)} only supports scalar done states." + ) + if done: + reset = True + break + if reset: tensordict = parent.reset(td_reset.clone(False)) break else: break trial += 1 - if trial > _MAX_NOOPS_TRIALS: - tensordict = parent.rand_step(tensordict) - if tensordict.get(("next", done_key)): - raise RuntimeError( - f"parent is still done after a single random step (i={i})." - ) - break - if tensordict.get(done_key): - raise RuntimeError("NoopResetEnv concluded with done environment") + else: + raise RuntimeError( + f"Parent env was repeatedly done or truncated" + f" before the sampled number of noops (={noops}) could be applied. " + ) + return tensordict.exclude(reward_key, inplace=True) def __repr__(self) -> str: @@ -2921,7 +3876,7 @@ def __init__(self, primers: dict = None, random=False, default_value=0.0, **kwar "The values of the primers must be a subtype of the TensorSpec class. " f"Got {type(spec)} instead." ) - super().__init__([]) + super().__init__() @property def device(self): @@ -2986,10 +3941,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(key, value) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for key in self.primers.keys(): - tensordict.setdefault(("next", key), tensordict.get(key, default=None)) - return tensordict + next_tensordict.setdefault(key, tensordict.get(key, default=None)) + return next_tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Sets the default values in the input tensordict. @@ -3024,7 +3981,7 @@ class PinMemoryTransform(Transform): """Calls pin_memory on the tensordict to facilitate writing on CUDA devices.""" def __init__(self): - super().__init__([]) + super().__init__() def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.pin_memory() @@ -3081,6 +4038,8 @@ class VecNorm(Transform): Args: in_keys (sequence of NestedKey, optional): keys to be updated. default: ["observation", "reward"] + out_keys (sequence of NestedKey, optional): destination keys. + Defaults to ``in_keys``. shared_td (TensorDictBase, optional): A shared tensordict containing the keys of the transform. decay (number, optional): decay rate of the moving average. 
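`VecNorm` keeps exponentially decayed running sums of the tracked values and of their squares; the `decay` argument documented above controls how quickly old data is forgotten. The following is a schematic, standalone sketch of that bookkeeping, not the class's actual code path (details such as the `eps` clamping may differ):

    import torch

    def ema_normalize(value, stats, decay=0.9999, eps=1e-4):
        # exponentially decayed sum, sum of squares and sample count
        stats["sum"] = decay * stats["sum"] + value
        stats["ssq"] = decay * stats["ssq"] + value.pow(2)
        stats["count"] = decay * stats["count"] + 1.0
        mean = stats["sum"] / stats["count"]
        var = (stats["ssq"] / stats["count"] - mean.pow(2)).clamp_min(0.0)
        return (value - mean) / var.sqrt().clamp_min(eps)

    stats = {"sum": torch.zeros(3), "ssq": torch.zeros(3), "count": torch.zeros(())}
    for _ in range(1000):
        obs = torch.randn(3) * 5.0 + 2.0
        normalized = ema_normalize(obs, stats)
    # after enough steps, `normalized` is approximately zero-mean, unit-variance

With `decay` close to 1 the statistics average over a long history; lowering it makes the normalization track recent data more aggressively.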
@@ -3117,6 +4076,7 @@ class VecNorm(Transform): def __init__( self, in_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Optional[Sequence[NestedKey]] = None, shared_td: Optional[TensorDictBase] = None, lock: mp.Lock = None, decay: float = 0.9999, @@ -3127,7 +4087,9 @@ def __init__( lock = mp.Lock() if in_keys is None: in_keys = ["observation", "reward"] - super().__init__(in_keys) + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) self._td = shared_td if shared_td is not None and not ( shared_td.is_shared() or shared_td.is_memmap() @@ -3375,129 +4337,221 @@ def __repr__(self) -> str: f"eps={self.eps:4.4f}, keys={self.in_keys})" ) + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + _lock = state.pop("lock", None) + if _lock is not None: + state["lock_placeholder"] = None + return state + + def __setstate__(self, state: Dict[str, Any]): + if "lock_placeholder" in state: + state.pop("lock_placeholder") + _lock = mp.Lock() + state["lock"] = _lock + self.__dict__.update(state) + class RewardSum(Transform): """Tracks episode cumulative rewards. This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative - value along each episode. When called, the transform creates a new tensordict key for each in_key named - ´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env - reward and be present in the env reward_spec. + value along the time dimension for each episode. + + When called, the transform writes a new tensordict entry for each ``in_key`` named + ``episode_{in_key}`` where the cumulative values are written. + + Args: + in_keys (list of NestedKeys, optional): Input reward keys. + All ´in_keys´ should be part of the environment reward_spec. + If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key. + However, multiple rewards (e.g. ``"reward1"`` and ``"reward2""``) can also be specified. + out_keys (list of NestedKeys, optional): The output sum keys, should be one per each input key. + reset_keys (list of NestedKeys, optional): the list of reset_keys to be + used, if the parent environment cannot be found. If provided, this + value will prevail over the environment ``reset_keys``. - If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards - (e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict, - this transform hos no effect. + Examples: + >>> from torchrl.envs.transforms import RewardSum, TransformedEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum()) + >>> td = env.reset() + >>> print(td["episode_reward"]) + tensor([0.]) + >>> td = env.rollout(3) + >>> print(td["next", "episode_reward"]) + tensor([[-0.5926], + [-1.4578], + [-2.7885]]) """ def __init__( self, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, + reset_keys: Optional[Sequence[NestedKey]] = None, ): """Initialises the transform. Filters out non-reward input keys and defines output keys.""" - if in_keys is None: - in_keys = ["reward"] - if out_keys is None and in_keys == ["reward"]: - out_keys = ["episode_reward"] - elif out_keys is None: - raise RuntimeError( - "the out_keys must be specified for non-conventional in-keys in RewardSum." 
+ super().__init__(in_keys=in_keys, out_keys=out_keys) + self._reset_keys = reset_keys + + @property + def in_keys(self): + in_keys = self.__dict__.get("_in_keys", None) + if in_keys in (None, []): + # retrieve rewards from parent env + parent = self.parent + if parent is None: + in_keys = ["reward"] + else: + in_keys = copy(parent.reward_keys) + self._in_keys = in_keys + return in_keys + + @in_keys.setter + def in_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._in_keys = value + + @property + def out_keys(self): + out_keys = self.__dict__.get("_out_keys", None) + if out_keys in (None, []): + out_keys = [ + _replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}") + for in_key in self.in_keys + ] + self._out_keys = out_keys + return out_keys + + @out_keys.setter + def out_keys(self, value): + # we must access the private attribute because this check occurs before + # the parent env is defined + if value is not None and len(self._in_keys) != len(value): + raise ValueError( + "RewardSum expects the same number of input and output keys" ) + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._out_keys = value - super().__init__(in_keys=in_keys, out_keys=out_keys) + @property + def reset_keys(self): + reset_keys = self.__dict__.get("_reset_keys", None) + if reset_keys is None: + parent = self.parent + if parent is None: + raise TypeError( + "reset_keys not provided but parent env not found. " + "Make sure that the reset_keys are provided during " + "construction if the transform does not have a container env." + ) + reset_keys = copy(parent.reset_keys) + self._reset_keys = reset_keys + return reset_keys + + @reset_keys.setter + def reset_keys(self, value): + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(val) for val in value] + self._reset_keys = value def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Resets episode rewards.""" - # Non-batched environments - _reset = tensordict.get("_reset", None) - if _reset is None: - _reset = torch.ones( - self.parent.done_spec.shape if self.parent else tensordict.batch_size, - dtype=torch.bool, - device=tensordict.device, - ) - if _reset.any(): - reward_key = self.parent.reward_key if self.parent else "reward" - for in_key, out_key in zip(self.in_keys, self.out_keys): - if out_key in tensordict.keys(True, True): - value = tensordict[out_key] - tensordict[out_key] = value.masked_fill( - expand_as_right(_reset, value), 0.0 - ) - elif unravel_key(in_key) == unravel_key(reward_key): + for in_key, reset_key, out_key in zip( + self.in_keys, self.reset_keys, self.out_keys + ): + _reset = tensordict.get(reset_key, None) + + if _reset is None or _reset.any(): + value = tensordict.get(out_key, default=None) + if value is not None: + if _reset is None: + tensordict.set(out_key, torch.zeros_like(value)) + else: + tensordict.set( + out_key, + value.masked_fill( + expand_as_right(_reset.squeeze(-1), value), 0.0 + ), + ) + else: # Since the episode reward is not in the tensordict, we need to allocate it # with zeros entirely (regardless of the _reset mask) - tensordict[out_key] = self.parent.reward_spec.zero() - else: - try: - tensordict[out_key] = self.parent.observation_spec[ - in_key - ].zero() - except KeyError as err: - raise KeyError( - f"The key {in_key} was not found in the parent " - 
f"observation_spec with keys " - f"{list(self.parent.observation_spec.keys(True))}. " - ) from err + tensordict.set( + out_key, + self.parent.full_reward_spec[in_key].zero(), + ) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: """Updates the episode rewards with the step rewards.""" # Update episode rewards - next_tensordict = tensordict.get("next") for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key in next_tensordict.keys(include_nested=True): reward = next_tensordict.get(in_key) - if out_key not in tensordict.keys(True): - tensordict.set(out_key, torch.zeros_like(reward)) - next_tensordict.set(out_key, tensordict.get(out_key) + reward) + prev_reward = tensordict.get(out_key, 0.0) + next_tensordict.set(out_key, prev_reward + reward) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") - tensordict.set("next", next_tensordict) - return tensordict + return next_tensordict - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - """Transforms the observation spec, adding the new keys generated by RewardSum.""" - # Retrieve parent reward spec - reward_spec = self.parent.reward_spec - reward_key = self.parent.reward_key if self.parent else "reward" - - episode_specs = {} - if isinstance(reward_spec, CompositeSpec): - # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec - if not all(k in reward_spec.keys(True, True) for k in self.in_keys): - raise KeyError("Not all in_keys are present in ´reward_spec´") - - # Define episode specs for all out_keys - for out_key in self.out_keys: - episode_spec = UnboundedContinuousTensorSpec( - shape=reward_spec.shape, - device=reward_spec.device, - dtype=reward_spec.dtype, - ) - episode_specs.update({out_key: episode_spec}) + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + state_spec = input_spec["full_state_spec"] + if state_spec is None: + state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) + state_spec.update(self._generate_episode_reward_spec()) + input_spec["full_state_spec"] = state_spec + return input_spec - else: - # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´ - if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}: - raise KeyError( - "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´" + def _generate_episode_reward_spec(self) -> CompositeSpec: + episode_reward_spec = CompositeSpec() + reward_spec = self.parent.full_reward_spec + reward_spec_keys = self.parent.reward_keys + # Define episode specs for all out_keys + for in_key, out_key in zip(self.in_keys, self.out_keys): + if ( + in_key in reward_spec_keys + ): # if this out_key has a corresponding key in reward_spec + out_key = _unravel_key_to_tuple(out_key) + temp_episode_reward_spec = episode_reward_spec + temp_rew_spec = reward_spec + for sub_key in out_key[:-1]: + if ( + not isinstance(temp_rew_spec, CompositeSpec) + or sub_key not in temp_rew_spec.keys() + ): + break + if sub_key not in temp_episode_reward_spec.keys(): + temp_episode_reward_spec[sub_key] = temp_rew_spec[ + sub_key + ].empty() + temp_rew_spec = temp_rew_spec[sub_key] + temp_episode_reward_spec = temp_episode_reward_spec[sub_key] + episode_reward_spec[out_key] = reward_spec[in_key].clone() + else: + raise ValueError( + f"The in_key: {in_key} is not present in 
the reward spec {reward_spec}." ) + return episode_reward_spec - # Define episode spec - episode_spec = UnboundedContinuousTensorSpec( - device=reward_spec.device, - dtype=reward_spec.dtype, - shape=reward_spec.shape, - ) - episode_specs.update({"episode_reward": episode_spec}) - - # Update observation_spec with episode_specs + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + """Transforms the observation spec, adding the new keys generated by RewardSum.""" if not isinstance(observation_spec, CompositeSpec): observation_spec = CompositeSpec( observation=observation_spec, shape=self.parent.batch_size ) - observation_spec.update(episode_specs) + observation_spec.update(self._generate_episode_reward_spec()) return observation_spec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -3515,20 +4569,78 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class StepCounter(Transform): - """Counts the steps from a reset and sets the done state to True after a certain number of steps. + """Counts the steps from a reset and optionally sets the truncated state to ``True`` after a certain number of steps. + + The ``"done"`` state is also adapted accordingly (as done is the union + of task completion and early truncation). Args: max_steps (int, optional): a positive integer that indicates the maximum number of steps to take before setting the ``truncated_key`` entry to ``True``. - However, the step count will still be - incremented on each call to step() into the `step_count` attribute. - truncated_key (NestedKey, optional): the key where the truncated key should - be written. Defaults to ``"truncated"``, which is recognised by + truncated_key (str, optional): the key where the truncated entries + should be written. Defaults to ``"truncated"``, which is recognised by data collectors as a reset signal. - step_count_key (NestedKey, optional): the key where the step_count key should - be written. Defaults to ``"step_count"``, which is recognised by - data collectors. + This argument can only be a string (not a nested key) as it will be + matched to each of the leaf done keys in the parent environment + (e.g., a ``("agent", "done")`` key will be accompanied by a + ``("agent", "truncated")`` if the ``"truncated"`` key name is used). + step_count_key (str, optional): the key where the step count entries + should be written. Defaults to ``"step_count"``. + This argument can only be a string (not a nested key) as it will be + matched to each of the leaf done keys in the parent environment + (e.g., a ``("agent", "done")`` key will be accompanied by a + ``("agent", "step_count")`` if the ``"step_count"`` key name is used). + update_done (bool, optional): if ``True``, the ``"done"`` boolean tensor + at the level of ``"truncated"`` + will be updated. + This signal indicates that the trajectory has reached its end, + either because the task is completed (``"completed"`` entry is + ``True``) or because it has been truncated (``"truncated"`` entry + is ``True``). + Defaults to ``True``. + + .. note:: To ensure compatibility with environments that have multiple + done_key(s), this transform will write a step_count entry for + every done entry within the tensordict. + + Examples: + >>> import gymnasium + >>> from torchrl.envs import GymWrapper + >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1")) + >>> env = TransformedEnv(base_env, + ... 
StepCounter(max_steps=5)) + >>> rollout = env.rollout(100) + >>> print(rollout) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + completed: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + completed: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + >>> print(rollout["next", "step_count"]) + tensor([[1], + [2], + [3], + [4], + [5]]) + """ invertible = False @@ -3536,61 +4648,174 @@ class StepCounter(Transform): def __init__( self, max_steps: Optional[int] = None, - truncated_key: Optional[NestedKey] = "truncated", - step_count_key: Optional[NestedKey] = "step_count", + truncated_key: str | None = "truncated", + step_count_key: str | None = "step_count", + update_done: bool = True, ): if max_steps is not None and max_steps < 1: raise ValueError("max_steps should have a value greater or equal to one.") + if not isinstance(truncated_key, str): + raise ValueError("truncated_key must be a string.") + if not isinstance(step_count_key, str): + raise ValueError("step_count_key must be a string.") self.max_steps = max_steps self.truncated_key = truncated_key self.step_count_key = step_count_key - super().__init__([]) + self.update_done = update_done + super().__init__() - def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - done_key = self.parent.done_key if self.parent else "done" - done = tensordict.get(done_key, None) - if done is None: - done = torch.ones( - self.parent.done_spec.shape, - dtype=self.parent.done_spec.dtype, - device=self.parent.done_spec.device, - ) - _reset = tensordict.get( - "_reset", - # TODO: decide if using done here, or using a default `True` tensor - default=None, - ) - if _reset is None: - _reset = torch.ones_like(done) - step_count = tensordict.get( - self.step_count_key, - default=None, - ) - if step_count is None: - step_count = torch.zeros_like(done, dtype=torch.int64) + @property + def truncated_keys(self): + truncated_keys = self.__dict__.get("_truncated_keys", None) + if truncated_keys is None: + # make the default truncated keys + truncated_keys = [] + for (done_key, *_) in self.parent.done_keys_groups: + if isinstance(done_key, str): + key = self.truncated_key + else: + key = (*done_key[:-1], self.truncated_key) + truncated_keys.append(key) + self._truncated_keys = truncated_keys + return truncated_keys - step_count[_reset] = 0 - tensordict.set( - self.step_count_key, - step_count, - ) - if self.max_steps is not None: - 
truncated = step_count >= self.max_steps - tensordict.set(self.truncated_key, truncated) - return tensordict + @property + def completed_keys(self): + done_keys = self.__dict__.get("_done_keys", None) + if done_keys is None: + # make the default done keys + done_keys = [] + for (done_key, *_) in self.parent.done_keys_groups: + if isinstance(done_key, str): + key = "done" + else: + key = (*done_key[:-1], "done") + done_keys.append(key) + self.__dict__["_done_keys"] = done_keys + return done_keys - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = tensordict.clone(False) - step_count = tensordict.get( - self.step_count_key, - ) - next_step_count = step_count + 1 - tensordict.set(("next", self.step_count_key), next_step_count) - if self.max_steps is not None: - truncated = next_step_count >= self.max_steps - tensordict.set(("next", self.truncated_key), truncated) + @property + def done_keys(self): + done_keys = self.__dict__.get("_done_keys", None) + if done_keys is None: + # make the default done keys + done_keys = [] + for (done_key, *_) in self.parent.done_keys_groups: + if isinstance(done_key, str): + key = "done" + else: + key = (*done_key[:-1], "done") + done_keys.append(key) + self.__dict__["_done_keys"] = done_keys + return done_keys + + @property + def terminated_keys(self): + terminated_keys = self.__dict__.get("_terminated_keys", None) + if terminated_keys is None: + # make the default terminated keys + terminated_keys = [] + for (terminated_key, *_) in self.parent.done_keys_groups: + if isinstance(terminated_key, str): + key = "terminated" + else: + key = (*terminated_key[:-1], "terminated") + terminated_keys.append(key) + self.__dict__["_terminated_keys"] = terminated_keys + return terminated_keys + + @property + def step_count_keys(self): + step_count_keys = self.__dict__.get("_step_count_keys", None) + if step_count_keys is None: + # make the default step_count keys + step_count_keys = [] + for (done_key, *_) in self.parent.done_keys_groups: + if isinstance(done_key, str): + key = self.step_count_key + else: + key = (*done_key[:-1], self.step_count_key) + step_count_keys.append(key) + self.__dict__["_step_count_keys"] = step_count_keys + return step_count_keys + + @property + def reset_keys(self): + if self.parent is not None: + return self.parent.reset_keys + # fallback on default "_reset" + return ["_reset"] + + @property + def done_keys_groups(self): + if self.parent is not None: + return self.parent.done_keys_groups + return [["done", "truncated"]] + + @property + def full_done_spec(self): + return self.parent.output_spec["full_done_spec"] if self.parent else None + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + # get reset signal + for step_count_key, truncated_key, reset_key, done_key, done_list_sorted in zip( + self.step_count_keys, + self.truncated_keys, + self.reset_keys, + self.done_keys, + self.done_keys_groups, + ): + step_count = tensordict.get(step_count_key, default=None) + reset = tensordict.get(reset_key, default=None) + if reset is None: + # get done status, just to inform the reset shape, dtype and device + for entry_name in done_list_sorted: + done = tensordict.get(entry_name, default=None) + if done is not None: + break + else: + # It may be the case that reset did not provide a done state, in which case + # we fall back on the spec + done = self.parent.output_spec["full_done_spec", entry_name].zero() + reset = torch.ones_like(done) + if step_count is None: + step_count = 
self.container.observation_spec[step_count_key].zero() + + # zero the step count if reset is needed + step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) + tensordict.set(step_count_key, step_count) + if self.max_steps is not None: + truncated = step_count >= self.max_steps + if self.update_done: + # we assume no done after reset + tensordict.set(done_key, truncated) + tensordict.set(truncated_key, truncated) return tensordict + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + for step_count_key, truncated_key, done_key, terminated_key in zip( + self.step_count_keys, + self.truncated_keys, + self.done_keys, + self.terminated_keys, + ): + step_count = tensordict.get(step_count_key) + next_step_count = step_count + 1 + next_tensordict.set(step_count_key, next_step_count) + if self.max_steps is not None: + truncated = next_step_count >= self.max_steps + if self.update_done: + done = next_tensordict.get(done_key, None) + terminated = next_tensordict.get(terminated_key, None) + if terminated is not None: + truncated = truncated & ~terminated + done = truncated | done # we assume no done after reset + next_tensordict.set(done_key, done) + next_tensordict.set(truncated_key, truncated) + return next_tensordict + def transform_observation_spec( self, observation_spec: CompositeSpec ) -> CompositeSpec: @@ -3598,36 +4823,126 @@ def transform_observation_spec( raise ValueError( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) - observation_spec[self.step_count_key] = UnboundedDiscreteTensorSpec( - shape=self.parent.done_spec.shape - if self.parent - else observation_spec.shape, - dtype=torch.int64, - device=observation_spec.device, - ) - observation_spec[self.step_count_key].space.minimum = ( - observation_spec[self.step_count_key].space.minimum * 0 - ) - if self.max_steps is not None and self.truncated_key != self.parent.done_key: - observation_spec[self.truncated_key] = self.parent.done_spec.clone() - return observation_spec + full_done_spec = self.parent.output_spec["full_done_spec"] + for step_count_key in self.step_count_keys: + step_count_key = unravel_key(step_count_key) + # find a matching done key (there might be more than one) + for done_key in self.done_keys: + # check root + if type(done_key) != type(step_count_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == step_count_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + + else: + raise KeyError( + f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." 
+ ) + observation_spec[step_count_key] = BoundedTensorSpec( + shape=shape, + dtype=torch.int64, + device=observation_spec.device, + low=0, + high=torch.iinfo(torch.int64).max, + ) + return super().transform_observation_spec(observation_spec) + + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + if self.max_steps: + full_done_spec = self.parent.output_spec["full_done_spec"] + for truncated_key in self.truncated_keys: + truncated_key = unravel_key(truncated_key) + # find a matching done key (there might be more than one) + for done_key in self.done_keys: + # check root + if type(done_key) != type(truncated_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == truncated_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + + else: + raise KeyError( + f"Could not find root of truncated_key {truncated_key} in done keys {self.done_keys}." + ) + full_done_spec[truncated_key] = DiscreteTensorSpec( + 2, dtype=torch.bool, device=output_spec.device, shape=shape + ) + if self.update_done: + for done_key in self.done_keys: + done_key = unravel_key(done_key) + # find a matching done key (there might be more than one) + for done_key in self.done_keys: + # check root + if type(done_key) != type(done_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == done_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + + else: + raise KeyError( + f"Could not find root of stop_key {done_key} in done keys {self.done_keys}." + ) + full_done_spec[done_key] = DiscreteTensorSpec( + 2, dtype=torch.bool, device=output_spec.device, shape=shape + ) + output_spec["full_done_spec"] = full_done_spec + return super().transform_output_spec(output_spec) def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: if not isinstance(input_spec, CompositeSpec): raise ValueError( f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead." ) - if input_spec["_state_spec"] is None: - input_spec["_state_spec"] = CompositeSpec( + if input_spec["full_state_spec"] is None: + input_spec["full_state_spec"] = CompositeSpec( shape=input_spec.shape, device=input_spec.device ) - step_spec = UnboundedDiscreteTensorSpec( - shape=self.parent.done_spec.shape if self.parent else input_spec.shape, - dtype=torch.int64, - device=input_spec.device, - ) - step_spec.space.minimum *= 0 - input_spec["_state_spec", self.step_count_key] = step_spec + + full_done_spec = self.parent.output_spec["full_done_spec"] + for step_count_key in self.step_count_keys: + step_count_key = unravel_key(step_count_key) + # find a matching done key (there might be more than one) + for done_key in self.done_keys: + # check root + if type(done_key) != type(step_count_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == step_count_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + + else: + raise KeyError( + f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." 
+ ) + + input_spec[ + unravel_key(("full_state_spec", step_count_key)) + ] = BoundedTensorSpec( + shape=shape, + dtype=torch.int64, + device=input_spec.device, + low=0, + high=torch.iinfo(torch.int64).max, + ) return input_spec @@ -3643,16 +4958,41 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class ExcludeTransform(Transform): - """Excludes keys from the input tensordict. + """Excludes keys from the data. Args: *excluded_keys (iterable of NestedKey): The name of the keys to exclude. If the key is not present, it is simply ignored. + Examples: + >>> import gymnasium + >>> from torchrl.envs import GymWrapper + >>> env = TransformedEnv( + ... GymWrapper(gymnasium.make("Pendulum-v1")), + ... ExcludeTransform("truncated") + ... ) + >>> env.rollout(3) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + """ def __init__(self, *excluded_keys): - super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[]) + super().__init__() try: excluded_keys = unravel_key_list(excluded_keys) except TypeError: @@ -3660,8 +5000,6 @@ def __init__(self, *excluded_keys): "excluded keys must be a list or tuple of strings or tuples of strings." 
) self.excluded_keys = excluded_keys - if "reward" in excluded_keys: - raise RuntimeError("'reward' cannot be excluded from the keys.") def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.exclude(*self.excluded_keys) @@ -3671,17 +5009,25 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def reset(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.exclude(*self.excluded_keys) - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if any(key in observation_spec.keys(True, True) for key in self.excluded_keys): - return CompositeSpec( - { - key: value - for key, value in observation_spec.items() - if unravel_key(key) not in self.excluded_keys - }, - shape=observation_spec.shape, - ) - return observation_spec + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + full_done_spec = output_spec["full_done_spec"] + full_reward_spec = output_spec["full_reward_spec"] + full_observation_spec = output_spec["full_observation_spec"] + for key in self.excluded_keys: + # done_spec + if unravel_key(key) in list(full_done_spec.keys(True, True)): + del full_done_spec[key] + continue + # reward_spec + if unravel_key(key) in list(full_reward_spec.keys(True, True)): + del full_reward_spec[key] + continue + # observation_spec + if unravel_key(key) in list(full_observation_spec.keys(True, True)): + del full_observation_spec[key] + continue + raise KeyError(f"Key {key} not found in the environment outputs.") + return output_spec class SelectTransform(Transform): @@ -3695,10 +5041,40 @@ class SelectTransform(Transform): *selected_keys (iterable of NestedKey): The name of the keys to select. If the key is not present, it is simply ignored. + Keyword Args: + keep_rewards (bool, optional): if ``False``, the reward keys must be provided + if they should be kept. Defaults to ``True``. + keep_dones (bool, optional): if ``False``, the done keys must be provided + if they should be kept. Defaults to ``True``. + + >>> import gymnasium + >>> from torchrl.envs import GymWrapper + >>> env = TransformedEnv( + ... GymWrapper(gymnasium.make("Pendulum-v1")), + ... SelectTransform("observation", "reward", "done", keep_dones=False), # we leave done behind + ... ) + >>> env.rollout(3) # the truncated key is now absent + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + """ - def __init__(self, *selected_keys): - super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[]) + def __init__(self, *selected_keys, keep_rewards=True, keep_dones=True): + super().__init__() try: selected_keys = unravel_key_list(selected_keys) except TypeError: @@ -3706,40 +5082,64 @@ def __init__(self, *selected_keys): "selected keys must be a list or tuple of strings or tuples of strings." 
) self.selected_keys = selected_keys + self.keep_done_keys = keep_dones + self.keep_reward_keys = keep_rewards def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.parent: + if self.parent is not None: input_keys = self.parent.input_spec.keys(True, True) else: input_keys = [] - reward_key = self.parent.reward_key if self.parent else "reward" - done_key = self.parent.done_key if self.parent else "done" + if self.keep_reward_keys: + reward_keys = self.parent.reward_keys if self.parent else ["reward"] + else: + reward_keys = [] + if self.keep_done_keys: + done_keys = self.parent.done_keys if self.parent else ["done"] + else: + done_keys = [] return tensordict.select( - *self.selected_keys, reward_key, done_key, *input_keys, strict=False + *self.selected_keys, *reward_keys, *done_keys, *input_keys, strict=False ) forward = _call def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.parent: + if self.parent is not None: input_keys = self.parent.input_spec.keys(True, True) else: input_keys = [] - reward_key = self.parent.reward_key if self.parent else "reward" - done_key = self.parent.done_key if self.parent else "done" + if self.keep_reward_keys: + reward_keys = self.parent.reward_keys if self.parent else ["reward"] + else: + reward_keys = [] + if self.keep_done_keys: + done_keys = self.parent.done_keys if self.parent else ["done"] + else: + done_keys = [] return tensordict.select( - *self.selected_keys, reward_key, done_key, *input_keys, strict=False + *self.selected_keys, *reward_keys, *done_keys, *input_keys, strict=False ) - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return CompositeSpec( - { - key: value - for key, value in observation_spec.items() - if unravel_key(key) in self.selected_keys - }, - shape=observation_spec.shape, - ) + def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + full_done_spec = output_spec["full_done_spec"] + full_reward_spec = output_spec["full_reward_spec"] + full_observation_spec = output_spec["full_observation_spec"] + if not self.keep_done_keys: + for key in list(full_done_spec.keys(True, True)): + if unravel_key(key) not in self.selected_keys: + del full_done_spec[key] + + for key in list(full_observation_spec.keys(True, True)): + if unravel_key(key) not in self.selected_keys: + del full_observation_spec[key] + + if not self.keep_reward_keys: + for key in list(full_reward_spec.keys(True, True)): + if unravel_key(key) not in self.selected_keys: + del full_reward_spec[key] + + return output_spec class TimeMaxPool(Transform): @@ -3751,6 +5151,30 @@ class TimeMaxPool(Transform): in_keys (sequence of NestedKey, optional): input keys on which the max pool will be applied. Defaults to "observation" if left empty. out_keys (sequence of NestedKey, optional): output keys where the output will be written. Defaults to `in_keys` if left empty. T (int, optional): Number of time steps over which to apply max pooling. 
+ + Examples: + >>> from torchrl.envs import GymEnv + >>> base_env = GymEnv("Pendulum-v1") + >>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10)) + >>> torch.manual_seed(0) + >>> env.set_seed(0) + >>> rollout = env.rollout(10) + >>> print(rollout["observation"]) # values should be increasing up until the 10th step + tensor([[ 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0216, 0.0000], + [ 0.0000, 0.1149, 0.0000], + [ 0.0000, 0.1990, 0.0000], + [ 0.0000, 0.2749, 0.0000], + [ 0.0000, 0.3281, 0.0000], + [-0.9290, 0.3702, -0.8978]]) + + .. note:: :class:`~TimeMaxPool` currently only supports ``done`` signal at the root. + Nested ``done``, such as those found in MARL settings, are currently not supported. + If this feature is needed, please raise an issue on TorchRL repo. + """ invertible = False @@ -3763,14 +5187,16 @@ def __init__( ): if in_keys is None: in_keys = ["observation"] + if out_keys is None: + out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) if T < 1: raise ValueError( - "TimeMaxPoolTranform T parameter should have a value greater or equal to one." + "TimeMaxPoolTransform T parameter should have a value greater or equal to one." ) if len(self.in_keys) != len(self.out_keys): raise ValueError( - "TimeMaxPoolTranform in_keys and out_keys don't have the same number of elements" + "TimeMaxPoolTransform in_keys and out_keys don't have the same number of elements" ) self.buffer_size = T for in_key in self.in_keys: @@ -3784,7 +5210,6 @@ def __init__( ) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - """Resets _buffers.""" # Non-batched environments if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1: for in_key in self.in_keys: @@ -3904,7 +5329,7 @@ def __init__( ) self.sample_dim = sample_dim self.mask_key = mask_key - super().__init__([]) + super().__init__() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = tensordict.shape @@ -3980,46 +5405,112 @@ class InitTracker(Transform): """ def __init__(self, init_key: NestedKey = "is_init"): - super().__init__(in_keys=[], out_keys=[init_key]) + if not isinstance(init_key, str): + raise ValueError("init_key can only be of type str.") + self.init_key = init_key + self.reset_key = "_reset" + super().__init__() - def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.out_keys[0] not in tensordict.keys(True, True): - device = tensordict.device - if device is None: - device = torch.device("cpu") - tensordict.set( - self.out_keys[0], - torch.zeros( - self.parent.done_spec.shape, device=device, dtype=torch.bool - ), + def set_container(self, container: Union[Transform, EnvBase]) -> None: + self._init_keys = None + return super().set_container(container) + + @property + def out_keys(self): + return self.init_keys + + @out_keys.setter + def out_keys(self, value): + if value in (None, []): + return + raise ValueError( + "Cannot set non-empty out-keys when out-keys are defined by the init_key value." 
+ ) + + @property + def init_keys(self): + init_keys = self.__dict__.get("_init_keys", None) + if init_keys is not None: + return init_keys + init_keys = [] + if self.parent is None: + raise NotImplementedError( + FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__) ) + for done_key, *_ in self.parent.done_keys_groups: + if isinstance(done_key, str): + init_key = self.init_key + else: + init_key = unravel_key((*done_key[:-1], self.init_key)) + init_keys.append(init_key) + self._init_keys = init_keys + return self._init_keys + + @property + def reset_keys(self): + return self.parent.reset_keys + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + for init_key, (done_key, *_) in zip( + self.init_keys, self.parent.done_keys_groups + ): + if init_key not in tensordict.keys(True, True): + device = tensordict.device + if device is None: + device = torch.device("cpu") + shape = self.parent.full_done_spec[done_key].shape + tensordict.set( + init_key, + torch.zeros(shape, device=device, dtype=torch.bool), + ) return tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: device = tensordict.device if device is None: device = torch.device("cpu") - _reset = tensordict.get("_reset", None) - if _reset is None: - tensordict.set( - self.out_keys[0], - torch.ones( - self.parent.done_spec.shape, - device=device, - dtype=torch.bool, - ), - ) - else: - tensordict.set(self.out_keys[0], _reset.clone()) + for reset_key, init_key, (done_key, *_) in zip( + self.reset_keys, self.init_keys, self.parent.done_keys_groups + ): + _reset = tensordict.get(reset_key, None) + if _reset is None: + shape = self.parent.full_done_spec[done_key].shape + tensordict.set( + init_key, + torch.ones( + shape, + device=device, + dtype=torch.bool, + ), + ) + else: + tensordict.set(init_key, _reset.clone()) return tensordict def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - observation_spec[self.out_keys[0]] = DiscreteTensorSpec( - 2, - dtype=torch.bool, - device=self.parent.device, - shape=self.parent.done_spec.shape, - ) + full_done_spec = self.parent.output_spec["full_done_spec"] + for init_key in self.init_keys: + for done_key in self.parent.done_keys: + # check root + if type(done_key) != type(init_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == init_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + else: + raise KeyError( + f"Could not find root of init_key {init_key} within done_keys {self.parent.done_keys}." + ) + observation_spec[init_key] = DiscreteTensorSpec( + 2, + dtype=torch.bool, + device=self.parent.device, + shape=shape, + ) return observation_spec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -4082,16 +5573,6 @@ class RenameTransform(Transform): def __init__( self, in_keys, out_keys, in_keys_inv=None, out_keys_inv=None, create_copy=False ): - if "done" in in_keys and not create_copy: - raise ValueError( - "Renaming 'done' is not allowed. Set `create_copy` to `True` " - "to create a copy of the done state." - ) - if "reward" in in_keys and not create_copy: - raise ValueError( - "Renaming 'reward' is not allowed. Set `create_copy` to `True` " - "to create a copy of the reward entry." 
- ) if in_keys_inv is None: in_keys_inv = [] if out_keys_inv is None: @@ -4115,13 +5596,21 @@ def __init__( def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.create_copy: - out = tensordict.select(*self.in_keys) + out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance) for in_key, out_key in zip(self.in_keys, self.out_keys): - out.rename_key_(in_key, out_key) + try: + tensordict.rename_key_(in_key, out_key) + except KeyError: + if not self._missing_tolerance: + raise tensordict = tensordict.update(out) else: for in_key, out_key in zip(self.in_keys, self.out_keys): - tensordict.rename_key_(in_key, out_key) + try: + tensordict.rename_key_(in_key, out_key) + except KeyError: + if not self._missing_tolerance: + raise return tensordict forward = _call @@ -4129,60 +5618,95 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: # no in-place modif if self.create_copy: - out = tensordict.select(*self.out_keys_inv) + out = tensordict.select( + *self.out_keys_inv, strict=not self._missing_tolerance + ) for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - out.rename_key_(out_key, in_key) + try: + out.rename_key_(out_key, in_key) + except KeyError: + if not self._missing_tolerance: + raise + tensordict = tensordict.update(out) else: for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - tensordict.rename_key_(out_key, in_key) + try: + tensordict.rename_key_(out_key, in_key) + except KeyError: + if not self._missing_tolerance: + raise return tensordict def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: - # we need to check whether there are special keys - output_spec = output_spec.clone() - if "done" in self.in_keys: - for i, out_key in enumerate(self.out_keys): # noqa: B007 - if self.in_keys[i] == "done": - break - else: - raise RuntimeError("Expected one key to be 'done'") - output_spec["_observation_spec"][out_key] = output_spec[ - "_done_spec" - ].clone() - if "reward" in self.in_keys: - for i, out_key in enumerate(self.out_keys): # noqa: B007 - if self.in_keys[i] == "reward": - break - else: - raise RuntimeError("Expected one key to be 'reward'") - output_spec["_observation_spec"][out_key] = output_spec[ - "_reward_spec" - ].clone() - for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in ("reward", "done"): - continue - if out_key in ("done", "reward"): - output_spec[out_key] = output_spec["_observation_spec"][in_key].clone() - else: - output_spec["_observation_spec"][out_key] = output_spec[ - "_observation_spec" - ][in_key].clone() - if not self.create_copy: - del output_spec["_observation_spec"][in_key] + for done_key in self.parent.done_keys: + if done_key in self.in_keys: + for i, out_key in enumerate(self.out_keys): # noqa: B007 + if self.in_keys[i] == done_key: + break + else: + # unreachable + raise RuntimeError + output_spec["full_done_spec"][out_key] = output_spec["full_done_spec"][ + done_key + ].clone() + if not self.create_copy: + del output_spec["full_done_spec"][done_key] + for reward_key in self.parent.reward_keys: + if reward_key in self.in_keys: + for i, out_key in enumerate(self.out_keys): # noqa: B007 + if self.in_keys[i] == reward_key: + break + else: + # unreachable + raise RuntimeError + output_spec["full_reward_spec"][out_key] = output_spec[ + "full_reward_spec" + ][reward_key].clone() + if not self.create_copy: + del output_spec["full_reward_spec"][reward_key] + for observation_key in 
self.parent.full_observation_spec.keys(True): + if observation_key in self.in_keys: + for i, out_key in enumerate(self.out_keys): # noqa: B007 + if self.in_keys[i] == observation_key: + break + else: + # unreachable + raise RuntimeError + output_spec["full_observation_spec"][out_key] = output_spec[ + "full_observation_spec" + ][observation_key].clone() + if not self.create_copy: + del output_spec["full_observation_spec"][observation_key] return output_spec def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: - # we need to check whether there are special keys - input_spec = input_spec.clone() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - in_key = (in_key,) if not isinstance(in_key, tuple) else in_key - out_key = (out_key,) if not isinstance(out_key, tuple) else out_key - input_spec[("_state_spec", *out_key)] = input_spec[ - ("_state_spec", *in_key) - ].clone() - if not self.create_copy: - del input_spec[("_state_spec", *in_key)] + for action_key in self.parent.action_keys: + if action_key in self.in_keys: + for i, out_key in enumerate(self.out_keys): # noqa: B007 + if self.in_keys[i] == action_key: + break + else: + # unreachable + raise RuntimeError + input_spec["full_action_spec"][out_key] = input_spec[ + "full_action_spec" + ][action_key].clone() + if not self.create_copy: + del input_spec["full_action_spec"][action_key] + for state_key in self.parent.full_state_spec.keys(True): + if state_key in self.in_keys: + for i, out_key in enumerate(self.out_keys): # noqa: B007 + if self.in_keys[i] == state_key: + break + else: + # unreachable + raise RuntimeError + input_spec["full_state_spec"][out_key] = input_spec["full_state_spec"][ + state_key + ].clone() + if not self.create_copy: + del input_spec["full_state_spec"][state_key] return input_spec @@ -4194,11 +5718,14 @@ class Reward2GoTransform(Transform): and not to the collector. Args: + gamma (float or torch.Tensor): the discount factor. Defaults to 1.0. in_keys (sequence of NestedKey): the entries to rename. Defaults to ``("next", "reward")`` if none is provided. out_keys (sequence of NestedKey): the entries to rename. Defaults to the values of ``in_keys`` if none is provided. - gamma (float or torch.Tensor): the discount factor. Defaults to 1.0. + done_key (NestedKey): the done entry. Defaults to ``"done"``. + truncated_key (NestedKey): the truncated entry. Defaults to ``"truncated"``. + If no truncated entry is found, only the ``"done"`` will be used. Examples: >>> # Using this transform as part of a replay buffer @@ -4280,6 +5807,9 @@ class Reward2GoTransform(Transform): >>> t = Reward2GoTransform(gamma=0.99) >>> TransformedEnv(GymEnv("Pendulum-v1"), t) # crashes + .. note:: In settings where multiple done entries are present, one should build + a single :class:`~Reward2GoTransform` for each done-reward pair. 
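To illustrate the note above, a minimal sketch (the agent-level key names below are hypothetical) builds one transform per done/reward pair, to be appended to a replay buffer rather than to the env:

    from torchrl.envs.transforms import Reward2GoTransform

    # one Reward2GoTransform per done/reward pair; key names are illustrative
    r2g_agent0 = Reward2GoTransform(
        gamma=0.99,
        in_keys=[("next", "agent0", "reward")],
        out_keys=[("agent0", "reward_to_go")],
        done_key=("agent0", "done"),
    )
    # ... and similarly for ("agent1", "reward") / ("agent1", "done"),
    # each appended to the replay buffer as in the examples above.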
+ """ ENV_ERR = ( @@ -4292,17 +5822,19 @@ def __init__( gamma: Optional[Union[float, torch.Tensor]] = 1.0, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, + done_key: Optional[NestedKey] = "done", ): if in_keys is None: in_keys = [("next", "reward")] if out_keys is None: - out_keys = deepcopy(in_keys) + out_keys = copy(in_keys) # out_keys = ["reward_to_go"] super().__init__( in_keys=in_keys, in_keys_inv=in_keys, out_keys_inv=out_keys, ) + self.done_key = done_key if not isinstance(gamma, torch.Tensor): gamma = torch.tensor(gamma) @@ -4310,14 +5842,9 @@ def __init__( self.register_buffer("gamma", gamma) def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: - done_key = self.parent.done_key if self.parent else "done" - done = tensordict.get(("next", done_key)) - truncated = tensordict.get(("next", "truncated"), None) - if truncated is not None: - done_or_truncated = done | truncated - else: - done_or_truncated = done - if not done_or_truncated.any(-2).all(): + done = tensordict.get(("next", self.done_key)) + + if not done.any(-2).all(): raise RuntimeError( "No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode." ) @@ -4325,13 +5852,8 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): if in_key in tensordict.keys(include_nested=True): found = True - item = self._inv_apply_transform( - tensordict.get(in_key), done_or_truncated - ) - tensordict.set( - out_key, - item, - ) + item = self._inv_apply_transform(tensordict.get(in_key), done) + tensordict.set(out_key, item) if not found: raise KeyError(f"Could not find any of the input keys {self.in_keys}.") return tensordict @@ -4350,3 +5872,267 @@ def _inv_apply_transform( def set_container(self, container): if isinstance(container, EnvBase) or container.parent is not None: raise ValueError(self.ENV_ERR) + + +class ActionMask(Transform): + """An adaptive action masker. + + This transform reads the mask from the input tensordict after the step is executed, + and adapts the mask of the one-hot / categorical action spec. + + .. note:: This transform will fail when used without an environment. + + Args: + action_key (NestedKey, optional): the key where the action tensor can be found. + Defaults to ``"action"``. + mask_key (NestedKey, optional): the key where the action mask can be found. + Defaults to ``"action_mask"``. + + Examples: + >>> import torch + >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec + >>> from torchrl.envs.transforms import ActionMask, TransformedEnv + >>> from torchrl.envs.common import EnvBase + >>> class MaskedEnv(EnvBase): + ... def __init__(self, *args, **kwargs): + ... super().__init__(*args, **kwargs) + ... self.action_spec = DiscreteTensorSpec(4) + ... self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)) + ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) + ... self.reward_spec = UnboundedContinuousTensorSpec(1) + ... + ... def _reset(self, data): + ... td = self.observation_spec.rand() + ... td.update(torch.ones_like(self.state_spec.rand())) + ... return td + ... + ... def _step(self, data): + ... td = self.observation_spec.rand() + ... mask = data.get("action_mask") + ... action = data.get("action") + ... 
mask = mask.scatter(-1, action.unsqueeze(-1), 0) + ... + ... td.set("action_mask", mask) + ... td.set("reward", self.reward_spec.rand()) + ... td.set("done", ~mask.any().view(1)) + ... return td + ... + ... def _set_seed(self, seed): + ... return seed + ... + >>> torch.manual_seed(0) + >>> base_env = MaskedEnv() + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> r["action_mask"] + tensor([[ True, True, True, True], + [ True, True, False, True], + [ True, True, False, False], + [ True, False, False, False]]) + + """ + + ACCEPTED_SPECS = ( + OneHotDiscreteTensorSpec, + DiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, + MultiDiscreteTensorSpec, + ) + SPEC_TYPE_ERROR = "The action spec must be one of {}. Got {} instead." + + def __init__( + self, action_key: NestedKey = "action", mask_key: NestedKey = "action_mask" + ): + if not isinstance(action_key, (tuple, str)): + raise ValueError( + f"The action key must be a nested key. Got {type(action_key)} instead." + ) + if not isinstance(mask_key, (tuple, str)): + raise ValueError( + f"The mask key must be a nested key. Got {type(mask_key)} instead." + ) + super().__init__( + in_keys=[action_key, mask_key], out_keys=[], in_keys_inv=[], out_keys_inv=[] + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self))) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + parent = self.parent + if parent is None: + raise RuntimeError( + f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment." + ) + mask = tensordict.get(self.in_keys[1]) + action_spec = self.container.action_spec + if not isinstance(action_spec, self.ACCEPTED_SPECS): + raise ValueError( + self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) + ) + action_spec.update_mask(mask) + return tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + action_spec = self.container.action_spec + if not isinstance(action_spec, self.ACCEPTED_SPECS): + raise ValueError( + self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) + ) + action_spec.update_mask(tensordict.get(self.in_keys[1], None)) + return tensordict + + +class VecGymEnvTransform(Transform): + """A transform for GymWrapper subclasses that handles the auto-reset in a consistent way. + + Gym, gymnasium and SB3 provide vectorized (read, parallel or batched) environments + that are automatically reset. When this occurs, the actual observation resulting + from the action is saved within a key in the info. + The class :class:`torchrl.envs.libs.gym.terminal_obs_reader` reads that observation + and stores it in a ``"final"`` key within the output tensordict. + In turn, this transform reads that final data, swaps it with the observation + written in its place that results from the actual reset, and saves the + reset output in a private container. The resulting data truly reflects + the output of the step. + + This class works from gym 0.13 till the most recent gymnasium version. + + .. note:: Gym versions < 0.22 did not return the final observations. For these, + we simply fill the next observations with NaN (because it is lost) and + do the swap at the next step. + + Then, when calling `env.reset`, the saved data is written back where it belongs + (and the `reset` is a no-op). 
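As a rough sketch of the setting this transform targets (assuming gymnasium's vector API is available and that, as described here, the wrapper accepts the batched env directly; the env id and batch size are illustrative):

    import gymnasium
    from torchrl.envs.libs.gym import GymWrapper

    # a batched, auto-resetting gym env; the wrapper appends VecGymEnvTransform for us
    vec_env = gymnasium.vector.AsyncVectorEnv(
        [lambda: gymnasium.make("CartPole-v1") for _ in range(2)]
    )
    env = GymWrapper(vec_env)
    rollout = env.rollout(3)  # step outputs reflect the true post-step data, not the auto-reset one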
+ + This transform is automatically appended to the gym env whenever the wrapper + is created with an async env. + + Args: + final_name (str, optional): the name of the final observation in the dict. + Defaults to `"final"`. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + + """ + + def __init__(self, final_name="final"): + self.final_name = final_name + super().__init__() + self._memo = {} + + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) + self._done_keys = None + self._obs_keys = None + return out + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + # save the final info + done = False + for done_key in self.done_keys: + # we assume dones can be broadcast + done = done | next_tensordict.get(done_key) + if done is False: + raise RuntimeError( + f"Could not find any done signal in tensordict:\n{tensordict}" + ) + self._memo["done"] = done + final = next_tensordict.pop(self.final_name, None) + # if anything's done, we need to swap the final obs + if done.any(): + done = done.squeeze(-1) + if final is not None: + saved_next = next_tensordict.select(*final.keys(True, True)).clone() + next_tensordict[done] = final[done] + else: + saved_next = next_tensordict.select(*self.obs_keys).clone() + for obs_key in self.obs_keys: + next_tensordict[obs_key][done] = torch.tensor(np.nan) + + self._memo["saved_next"] = saved_next + else: + self._memo["saved_next"] = None + return next_tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + done = self._memo.get("done", None) + reset = tensordict.get("_reset", done) + if done is not None: + done = done.view_as(reset) + if ( + reset is not done + and (reset != done).any() + and (not reset.all() or not reset.any()) + ): + raise RuntimeError( + "Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. " + f"Got reset={reset}\nand done={done}" + ) + # if not reset.any(), we don't need to do anything. + # if reset.all(), we don't either (bc GymWrapper will call a plain reset). + if reset is not None and reset.any() and not reset.all(): + saved_next = self._memo["saved_next"] + # reset = reset.view(tensordict.shape) + # we have a data container from the previous call to step + # that contains part of the observation we need. + # We can safely place them back in the reset result tensordict: + # in env.rollout(), the result of reset() is assumed to be just + # the td from previous step with updated values from reset. + # In our case, it will always be the case that all these values + # are properly set. + # collectors even take care of doing an extra masking so it's even + # safer. 
+ tensordict.update(saved_next) + for done_key in self.done_keys: + # Make sure that all done are False + done = tensordict.get(done_key, None) + if done is not None: + done = done.clone().fill_(0) + else: + done = torch.zeros( + (*tensordict.batch_size, 1), + device=tensordict.device, + dtype=torch.bool, + ) + tensordict.set(done_key, done) + tensordict.pop(self.final_name, None) + return tensordict + + @property + def done_keys(self) -> List[NestedKey]: + keys = self.__dict__.get("_done_keys", None) + if keys is None: + keys = self.parent.done_keys + # we just want the "done" key + _done_keys = [] + for key in keys: + if not isinstance(key, tuple): + key = (key,) + if key[-1] == "done": + _done_keys.append(unravel_key(key)) + if not len(_done_keys): + raise RuntimeError("Could not find a 'done' key in the env specs.") + self._done_keys = _done_keys + return keys + + @property + def obs_keys(self) -> List[NestedKey]: + keys = self.__dict__.get("_obs_keys", None) + if keys is None: + keys = list(self.parent.observation_spec.keys(True, True)) + self._obs_keys = keys + return keys + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if self.final_name in observation_spec.keys(True): + del observation_spec[self.final_name] + return observation_spec + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self))) diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index d8461d86911..e971b848d9b 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -358,7 +358,7 @@ def _embed_goal(self, tensordict): if "goal_image" not in tensordict.keys(): raise KeyError( f"{self.__class__.__name__}.reset() requires a `'goal_image'` key to be " - f"present in the input tensordict." + f"present in the input tensordict. Got keys {list(tensordict.keys())}." 
) tensordict_in = tensordict.select("goal_image").rename_key_( "goal_image", self.in_keys[0] @@ -369,21 +369,37 @@ def _embed_goal(self, tensordict): ) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: if "goal_embedding" not in tensordict.keys(): tensordict = self._embed_goal(tensordict) last_embedding_key = self.out_keys[0] last_embedding = tensordict.get(last_embedding_key, None) - tensordict = super()._step(tensordict) - cur_embedding = tensordict.get(("next", self.out_keys[0])) + next_tensordict = super()._step(tensordict, next_tensordict) + cur_embedding = next_tensordict.get(self.out_keys[0]) if last_embedding is not None: goal_embedding = tensordict["goal_embedding"] - reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - ( - -torch.norm(last_embedding - goal_embedding, dim=-1) + reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - ( + -torch.linalg.norm(last_embedding - goal_embedding, dim=-1) ) - tensordict.set(("next", "reward"), reward) - return tensordict + next_tensordict.set("reward", reward) + return next_tensordict def forward(self, tensordict): tensordict = super().forward(tensordict) return tensordict + + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + if "full_state_spec" in input_spec.keys(): + full_state_spec = input_spec["full_state_spec"] + else: + full_state_spec = CompositeSpec( + shape=input_spec.shape, device=input_spec.device + ) + # find the obs spec + in_key = self.in_keys[0] + spec = self.parent.output_spec["full_observation_spec"][in_key] + full_state_spec["goal_image"] = spec.clone() + input_spec["full_state_spec"] = full_state_spec + return super().transform_input_spec(input_spec) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 756d05652ca..cc1f05d3ffb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import contextlib + import importlib.util +import os +import re +from enum import Enum +from typing import Dict, List, Union import torch -from tensordict import is_tensor_collection, unravel_key +from tensordict import is_tensor_collection, TensorDictBase, unravel_key from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. @@ -19,7 +25,7 @@ set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) -from tensordict.tensordict import LazyStackedTensorDict, NestedKey, TensorDictBase +from tensordict.tensordict import LazyStackedTensorDict, NestedKey __all__ = [ "exploration_mode", @@ -30,10 +36,23 @@ "check_env_specs", "step_mdp", "make_composite_from_td", + "MarlGroupMapType", + "check_marl_grouping", ] -from torchrl.data import CompositeSpec +from torchrl.data import CompositeSpec, TensorSpec +from torchrl.data.utils import check_no_exclusive_keys + +ACTION_MASK_ERROR = RuntimeError( + "An out-of-bounds actions has been provided to an env with an 'action_mask' output." + " If you are using a custom policy, make sure to take the action mask into account when computing the output." + " If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment." 
+ "If you are using a ParallelEnv or another batched inventor, " + "make sure to add the transform to the ParallelEnv (and not to the sub-environments)." + " For more info on using action masks, see the docs at: " + "https://pytorch.org/rl/reference/envs.html#environments-with-masked-actions" +) def _convert_exploration_type(*, exploration_mode, exploration_type): @@ -54,9 +73,9 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_key: NestedKey = "reward", - done_key: NestedKey = "done", - action_key: NestedKey = "action", + reward_keys: Union[NestedKey, List[NestedKey]] = "reward", + done_keys: Union[NestedKey, List[NestedKey]] = "done", + action_keys: Union[NestedKey, List[NestedKey]] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. @@ -84,11 +103,11 @@ def step_mdp( be kept in the root tensordict (since it should not be present in the ``"next"`` entry). Default is ``True``. - reward_key (key, optional): the key where the reward is written. Defaults + reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults to "reward". - done_key (key, optional): the key where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults to "done". - action_key (key, optional): the key where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults to "action". Returns: @@ -171,9 +190,9 @@ def step_mdp( exclude_reward=exclude_reward, exclude_done=exclude_done, exclude_action=exclude_action, - reward_key=reward_key, - done_key=done_key, - action_key=action_key, + reward_keys=reward_keys, + done_keys=done_keys, + action_keys=action_keys, ) for td, ntd in zip(tensordict.tensordicts, next_tensordicts) ], @@ -184,17 +203,20 @@ def step_mdp( return next_tensordict return out - action_key = unravel_key(action_key) - done_key = unravel_key(done_key) - reward_key = unravel_key(reward_key) + if not isinstance(action_keys, list): + action_keys = [action_keys] + if not isinstance(done_keys, list): + done_keys = [done_keys] + if not isinstance(reward_keys, list): + reward_keys = [reward_keys] excluded = set() if exclude_reward: - excluded = {reward_key} + excluded = excluded.union(reward_keys) if exclude_done: - excluded = excluded.union({done_key}) + excluded = excluded.union(done_keys) if exclude_action: - excluded = excluded.union({action_key}) + excluded = excluded.union(action_keys) next_td = tensordict.get("next") out = next_td.empty() @@ -204,7 +226,8 @@ def step_mdp( if key != "next": _set(tensordict, out, key, total_key, excluded) elif not exclude_action: - _set_single_key(tensordict, out, action_key) + for action_key in action_keys: + _set_single_key(tensordict, out, action_key) for key in next_td.keys(): _set(next_td, out, key, total_key, excluded) if next_tensordict is not None: @@ -213,49 +236,76 @@ def step_mdp( return out -def _set_single_key(source, dest, key, clone=False): +def _set_single_key( + source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False +): # key should be already unraveled if isinstance(key, str): key = (key,) for k in key: - val = source.get(k) - if is_tensor_collection(val): - new_val = dest.get(k, None) - if new_val is None: - new_val = val.empty() - # dest.set(k, new_val) - dest._set_str(k, new_val, inplace=False, validated=True) - source 
= val - dest = new_val - else: - if clone: - val = val.clone() - # dest.set(k, val) - dest._set_str(k, val, inplace=False, validated=True) + try: + val = source._get_str(k, None) + if is_tensor_collection(val): + new_val = dest._get_str(k, None) + if new_val is None: + new_val = val.empty() + dest._set_str(k, new_val, inplace=False, validated=True) + source = val + dest = new_val + else: + if clone: + val = val.clone() + dest._set_str(k, val, inplace=False, validated=True) + # This is a temporary solution to understand if a key is heterogeneous + # while not having performance impact when the exception is not raised + except RuntimeError as err: + if re.match(r"Found more than one unique shape in the tensors", str(err)): + # this is a het key + for s_td, d_td in zip(source.tensordicts, dest.tensordicts): + _set_single_key(s_td, d_td, k, clone) + break + else: + raise err def _set(source, dest, key, total_key, excluded): total_key = total_key + (key,) non_empty = False if unravel_key(total_key) not in excluded: - val = source.get(key) - if is_tensor_collection(val): - new_val = dest.get(key, None) - if new_val is None: - new_val = val.empty() - non_empty_local = False - for subkey in val.keys(): - non_empty_local = ( - _set(val, new_val, subkey, total_key, excluded) or non_empty_local - ) - if non_empty_local: - # dest.set(key, new_val) - dest._set_str(key, new_val, inplace=False, validated=True) - non_empty = non_empty_local - else: - non_empty = True - # dest.set(key, val) - dest._set_str(key, val, inplace=False, validated=True) + try: + val = source.get(key) + if is_tensor_collection(val): + new_val = dest.get(key, None) + if new_val is None: + new_val = val.empty() + non_empty_local = False + for subkey in val.keys(): + non_empty_local = ( + _set(val, new_val, subkey, total_key, excluded) + or non_empty_local + ) + if non_empty_local: + # dest.set(key, new_val) + dest._set_str(key, new_val, inplace=False, validated=True) + non_empty = non_empty_local + else: + non_empty = True + # dest.set(key, val) + dest._set_str(key, val, inplace=False, validated=True) + # This is a temporary solution to understand if a key is heterogeneous + # while not having performance impact when the exception is not raised + except RuntimeError as err: + if re.match(r"Found more than one unique shape in the tensors", str(err)): + # this is a het key + non_empty_local = False + for s_td, d_td in zip(source.tensordicts, dest.tensordicts): + non_empty_local = ( + _set(s_td, d_td, key, total_key, excluded) or non_empty_local + ) + non_empty = non_empty_local + else: + raise err + return non_empty @@ -377,8 +427,9 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of an experiment and as such should be kept out of training scripts. 
""" - torch.manual_seed(seed) - env.set_seed(seed) + if seed is not None: + torch.manual_seed(seed) + env.set_seed(seed) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) @@ -388,9 +439,22 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) else: fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1) + # eliminate empty containers + fake_tensordict_select = fake_tensordict.select(*fake_tensordict.keys(True, True)) + real_tensordict_select = real_tensordict.select(*real_tensordict.keys(True, True)) + # check keys + fake_tensordict_keys = set(fake_tensordict.keys(True, True)) + real_tensordict_keys = set(real_tensordict.keys(True, True)) + if fake_tensordict_keys != real_tensordict_keys: + raise AssertionError( + f"""The keys of the specs and data do not match: + - List of keys present in real but not in fake: {real_tensordict_keys-fake_tensordict_keys}, + - List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}. +""" + ) if ( - fake_tensordict.apply(lambda x: torch.zeros_like(x)) - != real_tensordict.apply(lambda x: torch.zeros_like(x)) + fake_tensordict_select.apply(lambda x: torch.zeros_like(x)) + != real_tensordict_select.apply(lambda x: torch.zeros_like(x)) ).any(): raise AssertionError( "zeroing the two tensordicts did not make them identical. " @@ -398,21 +462,30 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): ) # Checks shapes and eventually dtypes of keys at all nesting levels - _per_level_env_check(fake_tensordict, real_tensordict, check_dtype=check_dtype) + _per_level_env_check( + fake_tensordict_select, real_tensordict_select, check_dtype=check_dtype + ) # Check specs last_td = real_tensordict[..., -1] - _action_spec = env.input_spec["_action_spec"] - _state_spec = env.input_spec["_state_spec"] - _obs_spec = env.output_spec["_observation_spec"] - _reward_spec = env.output_spec["_reward_spec"] - _done_spec = env.output_spec["_done_spec"] + last_td = env.rand_action(last_td) + full_action_spec = env.input_spec["full_action_spec"] + full_state_spec = env.input_spec["full_state_spec"] + full_observation_spec = env.output_spec["full_observation_spec"] + full_reward_spec = env.output_spec["full_reward_spec"] + full_done_spec = env.output_spec["full_done_spec"] for name, spec in ( - ("action", _action_spec), - ("state", _state_spec), - ("done", _done_spec), - ("obs", _obs_spec), + ("action", full_action_spec), + ("state", full_state_spec), + ("done", full_done_spec), + ("obs", full_observation_spec), ): + if not check_no_exclusive_keys(spec): + raise AssertionError( + "It appears you are using some LazyStackedCompositeSpecs with exclusive keys " + "(keys present in some but not all of the stacked specs). To use such heterogeneous specs, " + "you will need to first pass your stack through `torchrl.data.consolidate_spec`." + ) if spec is None: spec = CompositeSpec(shape=env.batch_size, device=env.device) td = last_td.select(*spec.keys(True, True), strict=True) @@ -421,9 +494,9 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): f"spec check failed at root for spec {name}={spec} and data {td}." 
) for name, spec in ( - ("reward", _reward_spec), - ("done", _done_spec), - ("obs", _obs_spec), + ("reward", full_reward_spec), + ("done", full_done_spec), + ("obs", full_observation_spec), ): if spec is None: spec = CompositeSpec(shape=env.batch_size, device=env.device) @@ -453,19 +526,6 @@ def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int return tensor -class classproperty: - """A class-property object. - - Usage: Allows for iterators coded as properties. - """ - - def __init__(self, fget): - self.fget = fget - - def __get__(self, owner_self, owner_cls): - return self.fget(owner_cls) - - def _sort_keys(element): if isinstance(element, tuple): element = unravel_key(element) @@ -497,7 +557,7 @@ def make_composite_from_td(data): obs: UnboundedContinuousTensorSpec( shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous), reward: UnboundedContinuousTensorSpec( - shape=torch.Size([1]), space=ContinuousBox(minimum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), maximum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) + shape=torch.Size([1]), space=ContinuousBox(low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> assert (spec.zero() == data.zero_()).all() """ from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec @@ -509,10 +569,333 @@ def make_composite_from_td(data): key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + dtype=tensor.dtype, + device=tensor.device, + shape=tensor.shape if tensor.shape else [1], ) for key, tensor in data.items() }, shape=data.shape, ) return composite + + +@contextlib.contextmanager +def clear_mpi_env_vars(): + """Clears the MPI of environment variables. + + `from mpi4py import MPI` will call `MPI_Init` by default. + If the child process has MPI environment variables, MPI will think that the child process + is an MPI process just like the parent and do bad things such as hang. + + This context manager is a hacky way to clear those environment variables + temporarily such as when we are starting multiprocessing Processes. + + Yields: + Yields for the context manager + """ + removed_environment = {} + for k, v in list(os.environ.items()): + for prefix in ["OMPI_", "PMI_"]: + if k.startswith(prefix): + removed_environment[k] = v + del os.environ[k] + try: + yield + finally: + os.environ.update(removed_environment) + + +def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: + if isinstance(key, str): + return new_ending + else: + return key[:-1] + (new_ending,) + + +class MarlGroupMapType(Enum): + """Marl Group Map Type. + + As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment. + You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same + neural network. You can split agents in different groups where they are heterogenous or should be processed by + different neural networks. 
To group, you just need to pass a ``group_map`` at env constructiuon time. + + Otherwise, you can choose one of the premade grouping strategies from this class. + + - With ``group_map=MarlGroupMapType.ALL_IN_ONE_GROUP`` and + agents ``["agent_0", "agent_1", "agent_2", "agent_3"]``, + the tensordicts coming and going from your environment will look + something like: + + >>> print(env.rand_action(env.reset())) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 9]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([4, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([4]))}, + batch_size=torch.Size([])) + >>> print(env.group_map) + {"agents": ["agent_0", "agent_1", "agent_2", "agent_3]} + + - With ``group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT`` and + agents ``["agent_0", "agent_1", "agent_2", "agent_3"]``, + the tensordicts coming and going from your environment will look + something like: + + >>> print(env.rand_action(env.reset())) + TensorDict( + fields={ + agent_0: TensorDict( + fields={ + action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([]))}, + agent_1: TensorDict( + fields={ + action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([]))}, + agent_2: TensorDict( + fields={ + action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([]))}, + agent_3: TensorDict( + fields={ + action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([]))}, + batch_size=torch.Size([])) + >>> print(env.group_map) + {"agent_0": ["agent_0"], "agent_1": ["agent_1"], "agent_2": ["agent_2"], "agent_3": ["agent_3"]} + """ + + ALL_IN_ONE_GROUP = 1 + ONE_GROUP_PER_AGENT = 2 + + def get_group_map(self, agent_names: List[str]): + if self == MarlGroupMapType.ALL_IN_ONE_GROUP: + return {"agents": agent_names} + elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT: + return {agent_name: [agent_name] for agent_name in agent_names} + + +def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]): + """Check MARL group map. + + Performs checks on the group map of a marl environment to assess its validity. + Raises an error in cas of an invalid group_map. 
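# A small usage sketch of the validation described above, with hypothetical
# agent names (the import mirrors the one used in the docstring example):
from torchrl.envs.utils import MarlGroupMapType, check_marl_grouping

agents = ["agent_0", "agent_1", "agent_2"]
# a valid grouping: one group per agent
check_marl_grouping(MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agents), agents)
# an invalid grouping: an agent listed twice raises a ValueError
try:
    check_marl_grouping({"team": ["agent_0", "agent_0", "agent_1"]}, agents)
except ValueError as err:
    print(err)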
+ + Args: + group_map (Dict[str, List[str]]): the group map mapping group names to list of agent names in the group + agent_names (List[str]): a list of all the agent names in the environment4 + + Examples: + >>> from torchrl.envs.utils import MarlGroupMapType, check_marl_grouping + >>> agent_names = ["agent_0", "agent_1", "agent_2"] + >>> check_marl_grouping(MarlGroupMapType.ALL_IN_ONE_GROUP.get_group_map(agent_names), agent_names) + + """ + n_agents = len(agent_names) + if n_agents == 0: + raise ValueError("No agents passed") + if len(set(agent_names)) != n_agents: + raise ValueError("There are agents with the same name") + if len(group_map.keys()) > n_agents: + raise ValueError( + f"Number of groups {len(group_map.keys())} greater than number of agents {n_agents}" + ) + found_agents = {agent_name: False for agent_name in agent_names} + for group_name, group in group_map.items(): + if not len(group): + raise ValueError(f"Group {group_name} is empty") + for agent_name in group: + if agent_name not in found_agents: + raise ValueError(f"Agent {agent_name} not present in environment") + if not found_agents[agent_name]: + found_agents[agent_name] = True + else: + raise ValueError(f"Agent {agent_name} present more than once") + for agent_name, found in found_agents.items(): + if not found: + raise ValueError(f"Agent {agent_name} not found in any group") + + +def terminated_or_truncated( + data: TensorDictBase, + full_done_spec: TensorSpec | None = None, + key: str = "_reset", + write_full_false: bool = False, +) -> bool: + """Reads the done / terminated / truncated keys within a tensordict, and writes a new tensor where the values of both signals are aggregated. + + The modification occurs in-place within the TensorDict instance provided. + This function can be used to compute the `"_reset"` signals in batched + or multiagent settings, hence the default name of the output key. + + Args: + data (TensorDictBase): the input data, generally resulting from a call + to :meth:`~torchrl.envs.EnvBase.step`. + full_done_spec (TensorSpec, optional): the done_spec from the env, + indicating where the done leaves have to be found. + If not provided, the default + ``"done"``, ``"terminated"`` and ``"truncated"`` entries will be + searched for in the data. + key (NestedKey, optional): where the aggregated result should be written. + If ``None``, then the function will not write any key but just output + whether any of the done values was true. + .. note:: if a value is already present for the ``key`` entry, + the previous value will prevail and no update will be achieved. + write_full_false (bool, optional): if ``True``, the reset keys will be + written even if the output is ``False`` (ie, no done is ``True`` + in the provided data structure). + Defaults to ``False``. + + Returns: a boolean value indicating whether any of the done states found in the data + contained a ``True``. + + Examples: + >>> from torchrl.data.tensor_specs import DiscreteTensorSpec + >>> from tensordict import TensorDict + >>> spec = CompositeSpec( + ... done=DiscreteTensorSpec(2, dtype=torch.bool), + ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + ... nested=CompositeSpec( + ... done=DiscreteTensorSpec(2, dtype=torch.bool), + ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + ... ) + ... ) + >>> data = TensorDict({ + ... "done": True, "truncated": False, + ... "nested": {"done": False, "truncated": True}}, + ... batch_size=[] + ... 
) + >>> data = terminated_or_truncated(data, spec) + >>> print(data["_reset"]) + tensor(True) + >>> print(data["nested", "_reset"]) + tensor(True) + """ + list_of_keys = [] + + def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()): + any_eot = False + aggregate = None + if full_done_spec is None: + for eot_key, item in data.items(): + if eot_key == "done": + done = data.get(eot_key, None) + if done is None: + done = torch.zeros( + (*data.shape, 1), dtype=torch.bool, device=data.device + ) + if aggregate is None: + aggregate = torch.tensor(False, device=done.device) + aggregate = aggregate | done + elif eot_key in ("terminated", "truncated"): + done = data.get(eot_key, None) + if done is None: + done = torch.zeros( + (*data.shape, 1), dtype=torch.bool, device=data.device + ) + if aggregate is None: + aggregate = torch.tensor(False, device=done.device) + aggregate = aggregate | done + elif isinstance(item, TensorDictBase): + any_eot = any_eot | inner_terminated_or_truncated( + data=item, + full_done_spec=None, + key=key, + curr_done_key=curr_done_key + (eot_key,), + ) + else: + for eot_key, item in full_done_spec.items(): + if isinstance(item, CompositeSpec): + any_eot = any_eot | inner_terminated_or_truncated( + data=data.get(eot_key), + full_done_spec=item, + key=key, + curr_done_key=curr_done_key + (eot_key,), + ) + else: + sop = data.get(eot_key, None) + if sop is None: + sop = torch.zeros( + (*data.shape, 1), dtype=torch.bool, device=data.device + ) + if aggregate is None: + aggregate = torch.tensor(False, device=sop.device) + aggregate = aggregate | sop + if aggregate is not None: + if key is not None: + data.set(key, aggregate) + list_of_keys.append(curr_done_key + (key,)) + any_eot = any_eot | aggregate.any() + return any_eot + + any_eot = inner_terminated_or_truncated(data, full_done_spec, key) + if not any_eot and not write_full_false: + # remove the list of reset keys + data.exclude(*list_of_keys, inplace=True) + return any_eot + + +PARTIAL_MISSING_ERR = "Some reset keys were present but not all. Either all the `'_reset'` entries must be present, or none." + + +def _aggregate_resets(data: TensorDictBase, reset_keys=None) -> torch.Tensor: + # goes through the tensordict and brings the _reset information to + # a boolean tensor of the shape of the tensordict. + batch_size = data.batch_size + n = len(batch_size) + + if reset_keys is not None: + reset = False + has_missing = None + for key in reset_keys: + local_reset = data.get(key, None) + if local_reset is None: + if has_missing is False: + raise ValueError(PARTIAL_MISSING_ERR) + has_missing = True + continue + elif has_missing: + raise ValueError(PARTIAL_MISSING_ERR) + has_missing = False + if local_reset.ndim > n: + local_reset = local_reset.flatten(n, local_reset.ndim - 1) + local_reset = local_reset.any(-1) + reset = reset | local_reset + if has_missing: + return torch.ones(batch_size, dtype=torch.bool, device=data.device) + return reset + + reset = torch.tensor(False, device=data.device) + + def skim_through(td, reset=reset): + for key in td.keys(): + if key == "_reset": + local_reset = td.get(key) + if local_reset.ndim > n: + local_reset = local_reset.flatten(n, local_reset.ndim - 1) + local_reset = local_reset.any(-1) + reset = reset | local_reset + # we need to check the entry class without getting the value, + # because some lazy tensordicts may prevent calls to items(). + # This introduces some slight overhead as when we encounter a + # tensordict item, we'll need to get it twice. 
+ elif is_tensor_collection(td.entry_class(key)): + value = td.get(key) + reset = skim_through(value, reset=reset) + return reset + + reset = skim_through(data) + return reset diff --git a/torchrl/envs/vec_envs.py b/torchrl/envs/vec_envs.py new file mode 100644 index 00000000000..73dd159751c --- /dev/null +++ b/torchrl/envs/vec_envs.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import warnings + +warnings.warn("vec_env.py has moved to batch_envs.py.", category=DeprecationWarning) + +from .batched_envs import * # noqa: F403, F401 diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index ebb73bcedf6..16d621f2bec 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -8,8 +8,10 @@ distributions_maps, IndependentNormal, MaskedCategorical, + MaskedOneHotCategorical, NormalParamWrapper, OneHotCategorical, + ReparamGradientStrategy, TanhDelta, TanhNormal, TruncatedNormal, @@ -20,16 +22,20 @@ DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, + DecisionTransformer, DistributionalDQNnet, DreamerActor, + DTActor, DuelingCnnDQNet, LSTMNet, MLP, + MultiAgentConvNet, MultiAgentMLP, NoisyLazyLinear, NoisyLinear, ObsDecoder, ObsEncoder, + OnlineDTActor, QMixer, reset_noise, RSSMPosterior, @@ -44,10 +50,13 @@ ActorCriticWrapper, ActorValueOperator, AdditiveGaussianWrapper, + DecisionTransformerInferenceWrapper, DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, + EGreedyModule, EGreedyWrapper, + GRUModule, LMHeadActorValueOperator, LSTMModule, OrnsteinUhlenbeckProcessWrapper, diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index ab358adfb00..a3c5d0d4774 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -12,7 +12,13 @@ TanhNormal, TruncatedNormal, ) -from .discrete import __all__ as _all_discrete, MaskedCategorical, OneHotCategorical +from .discrete import ( + __all__ as _all_discrete, + MaskedCategorical, + MaskedOneHotCategorical, + OneHotCategorical, + ReparamGradientStrategy, +) distributions_maps = { distribution_class.lower(): eval(distribution_class) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 52d52e3113e..bb98b1412a8 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -3,7 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Sequence, Union +from enum import Enum +from functools import wraps +from typing import Any, Optional, Sequence, Union import torch import torch.distributions as D @@ -32,6 +34,25 @@ def rand_one_hot(values: torch.Tensor, do_softmax: bool = True) -> torch.Tensor: return out +class _one_hot_wrapper: + def __init__(self, parent_dist): + self.parent_dist = parent_dist + + def __call__(self, func): + @wraps(func) + def wrapped(_self, *args, **kwargs): + out = getattr(self.parent_dist, func.__name__)(_self, *args, **kwargs) + n = _self.num_samples + return torch.nn.functional.one_hot(out, n) + + return wrapped + + +class ReparamGradientStrategy(Enum): + PassThrough: Any = 1 + RelaxedOneHot: Any = 2 + + class OneHotCategorical(D.Categorical): """One-hot categorical distribution. 
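# A standalone sketch of the PassThrough strategy defined above: the hard
# one-hot sample is returned as-is, while gradients flow through the softmax
# probabilities (a straight-through estimator). Plain torch, illustrative values.
import torch

logits = torch.randn(4, requires_grad=True)
probs = torch.softmax(logits, dim=-1)
index = torch.multinomial(probs, 1).squeeze(-1)
hard = torch.nn.functional.one_hot(index, probs.shape[-1]).to(probs.dtype)
sample = hard + probs - probs.detach()   # value == hard one-hot, grad path == probs

values = torch.arange(4.0)
(sample * values).sum().backward()
print(logits.grad)                        # non-zero although the sample is discrete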
@@ -41,6 +62,13 @@ class OneHotCategorical(D.Categorical): Args: logits (torch.Tensor): event log probabilities (unnormalized) probs (torch.Tensor): event probabilities + grad_method (ReparamGradientStrategy, optional): strategy to gather + reparameterized samples. + ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients + by using the softmax valued log-probability as a proxy to the + samples gradients. + ``ReparamGradientStrategy.RelaxedOneHot`` will use + :class:`torch.distributions.RelaxedOneHot` to sample from the distribution. Examples: >>> torch.manual_seed(0) @@ -59,11 +87,14 @@ def __init__( self, logits: Optional[torch.Tensor] = None, probs: Optional[torch.Tensor] = None, + grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs, ) -> None: logits = _treat_categorical_params(logits) probs = _treat_categorical_params(probs) + self.grad_method = grad_method super().__init__(probs=probs, logits=logits, **kwargs) + self.num_samples = self._param.shape[-1] def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1)) @@ -75,14 +106,11 @@ def mode(self) -> torch.Tensor: else: return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None ) -> torch.Tensor: - if sample_shape is None: - sample_shape = torch.Size([]) - out = super().sample(sample_shape=sample_shape) - out = torch.nn.functional.one_hot(out, self.logits.shape[-1]).to(torch.long) - return out + ... def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor: if sample_shape is None: @@ -93,12 +121,25 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten else: logits = None probs = self.probs - d = D.relaxed_categorical.RelaxedOneHotCategorical( - 1.0, probs=probs, logits=logits - ) - out = d.rsample(sample_shape) - out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype)) - return out + if self.grad_method == ReparamGradientStrategy.RelaxedOneHot: + d = D.relaxed_categorical.RelaxedOneHotCategorical( + 1.0, probs=probs, logits=logits + ) + out = d.rsample(sample_shape) + out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype)) + return out + elif self.grad_method == ReparamGradientStrategy.PassThrough: + if logits is not None: + probs = self.probs + else: + probs = torch.softmax(self.logits, dim=-1) + out = self.sample(sample_shape) + out = out + probs - probs.detach() + return out + else: + raise ValueError( + f"Unknown reparametrization strategy {self.reparam_strategy}." + ) class MaskedCategorical(D.Categorical): @@ -112,6 +153,8 @@ class MaskedCategorical(D.Categorical): probs (torch.Tensor): event probabilities. If provided, the probabilities corresponding to to masked items will be zeroed and the probability re-normalized along its last dimension. + + Keyword Args: mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs`` where ``False`` entries are the ones to be masked. 
Alternatively, if ``sparse_mask`` is True, it represents the list of valid indices @@ -149,6 +192,7 @@ def __init__( self, logits: Optional[torch.Tensor] = None, probs: Optional[torch.Tensor] = None, + *, mask: torch.Tensor = None, indices: torch.Tensor = None, neg_inf: float = float("-inf"), @@ -174,6 +218,7 @@ def __init__( probs[~mask] = 0 probs = probs / probs.sum(-1, keepdim=True) logits = probs.log() + num_samples = logits.shape[-1] logits = self._mask_logits( logits, mask, @@ -186,28 +231,27 @@ def __init__( self._sparse_mask = sparse_mask self._padding_value = padding_value super().__init__(logits=logits) + self.num_samples = num_samples def sample( self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None ) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size() + else: + sample_shape = torch.Size(sample_shape) ret = super().sample(sample_shape) if not self._sparse_mask: return ret size = ret.size() - # Python 3.7 doesn't support math.prod - # outer_dim = prod(sample_shape) - # inner_dim = prod(self._mask.size()[:-1]) - outer_dim = torch.empty(sample_shape, device="meta").numel() - inner_dim = self._mask.numel() // self._mask.size(-1) + outer_dim = sample_shape.numel() + inner_dim = self._mask.shape[:-1].numel() idx_3d = self._mask.expand(outer_dim, inner_dim, -1) ret = idx_3d.gather(dim=-1, index=ret.view(outer_dim, inner_dim, 1)) - return ret.view(size) + return ret.reshape(size) - # # # TODO: Improve performance here. def log_prob(self, value: torch.Tensor) -> torch.Tensor: if not self._sparse_mask: return super().log_prob(value) @@ -247,3 +291,166 @@ def _mask_logits( if padding_value is not None: logits.masked_fill_(padding_mask, neg_inf) return logits + + +class MaskedOneHotCategorical(MaskedCategorical): + """MaskedCategorical distribution. + + Reference: + https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical + + Args: + logits (torch.Tensor): event log probabilities (unnormalized) + probs (torch.Tensor): event probabilities. If provided, the probabilities + corresponding to to masked items will be zeroed and the probability + re-normalized along its last dimension. + + Keyword Args: + mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs`` + where ``False`` entries are the ones to be masked. Alternatively, + if ``sparse_mask`` is True, it represents the list of valid indices + in the distribution. Exclusive with ``indices``. + indices (torch.Tensor): A dense index tensor representing which actions + must be taken into account. Exclusive with ``mask``. + neg_inf (float, optional): The log-probability value allocated to + invalid (out-of-mask) indices. Defaults to -inf. + padding_value: The padding value in the then mask tensor when + sparse_mask == True, the padding_value will be ignored. + grad_method (ReparamGradientStrategy, optional): strategy to gather + reparameterized samples. + ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients + by using the softmax valued log-probability as a proxy to the + samples gradients. + ``ReparamGradientStrategy.RelaxedOneHot`` will use + :class:`torch.distributions.RelaxedOneHot` to sample from the distribution. 
+ + >>> torch.manual_seed(0) + >>> logits = torch.randn(4) / 100 # almost equal probabilities + >>> mask = torch.tensor([True, False, True, True]) + >>> dist = MaskedOneHotCategorical(logits=logits, mask=mask) + >>> sample = dist.sample((10,)) + >>> print(sample) # no `1` in the sample + tensor([[0, 0, 1, 0], + [0, 0, 0, 1], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 0]]) + >>> print(dist.log_prob(sample)) + tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831, + -1.1203, -1.1203]) + >>> sample_non_valid = torch.zeros_like(sample) + >>> sample_non_valid[..., 1] = 1 + >>> print(dist.log_prob(sample_non_valid)) + tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) + >>> # with probabilities + >>> prob = torch.ones(10) + >>> prob = prob / prob.sum() + >>> mask = torch.tensor([False] + 9 * [True]) # first outcome is masked + >>> dist = MaskedOneHotCategorical(probs=prob, mask=mask) + >>> s = torch.arange(10) + >>> s = torch.nn.functional.one_hot(s, 10) + >>> print(dist.log_prob(s)) + tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, + -2.1972, -2.1972]) + """ + + def __init__( + self, + logits: Optional[torch.Tensor] = None, + probs: Optional[torch.Tensor] = None, + mask: torch.Tensor = None, + indices: torch.Tensor = None, + neg_inf: float = float("-inf"), + padding_value: Optional[int] = None, + grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, + ) -> None: + self.grad_method = grad_method + super().__init__( + logits=logits, + probs=probs, + mask=mask, + indices=indices, + neg_inf=neg_inf, + padding_value=padding_value, + ) + + @_one_hot_wrapper(MaskedCategorical) + def sample( + self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None + ) -> torch.Tensor: + ... 
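# A usage sketch, assuming this PR's MaskedOneHotCategorical and
# ReparamGradientStrategy are exposed from torchrl.modules as shown in the
# __init__ changes (illustrative only):
import torch
from torchrl.modules import MaskedOneHotCategorical, ReparamGradientStrategy

logits = torch.randn(4, requires_grad=True)
mask = torch.tensor([True, False, True, True])
dist = MaskedOneHotCategorical(
    logits=logits, mask=mask, grad_method=ReparamGradientStrategy.PassThrough
)
sample = dist.rsample((8,))                    # one-hot samples, index 1 never selected
(sample * torch.arange(4.0)).sum().backward()
print(logits.grad)                             # gradients reach the logits via the soft probs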
+ + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + return super().log_prob(value.argmax(dim=-1)) + + def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor: + if sample_shape is None: + sample_shape = torch.Size([]) + if hasattr(self, "logits") and self.logits is not None: + logits = self.logits + probs = None + else: + logits = None + probs = self.probs + if self.grad_method == ReparamGradientStrategy.RelaxedOneHot: + if self._sparse_mask: + if probs is not None: + probs_extended = torch.full( + (*probs.shape[:-1], self.num_samples), + 0, + device=probs.device, + dtype=probs.dtype, + ) + probs_extended = torch.scatter( + probs_extended, -1, self._mask, probs + ) + logits_extended = None + else: + probs_extended = torch.full( + (*logits.shape[:-1], self.num_samples), + self.neg_inf, + device=logits.device, + dtype=logits.dtype, + ) + logits_extended = torch.scatter( + probs_extended, -1, self._mask, logits + ) + probs_extended = None + else: + probs_extended = probs + logits_extended = logits + + d = D.relaxed_categorical.RelaxedOneHotCategorical( + 1.0, probs=probs_extended, logits=logits_extended + ) + out = d.rsample(sample_shape) + out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype)) + return out + elif self.grad_method == ReparamGradientStrategy.PassThrough: + if logits is not None: + probs = self.probs + else: + probs = torch.softmax(self.logits, dim=-1) + if self._sparse_mask: + probs_extended = torch.full( + (*probs.shape[:-1], self.num_samples), + 0, + device=probs.device, + dtype=probs.dtype, + ) + probs_extended = torch.scatter(probs_extended, -1, self._mask, probs) + else: + probs_extended = probs + + out = self.sample(sample_shape) + out = out + probs_extended - probs_extended.detach() + return out + else: + raise ValueError( + f"Unknown reparametrization strategy {self.reparam_strategy}." + ) diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index 1dfde393709..59b95658ea5 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -87,7 +87,6 @@ def mean(self): def variance(self): return self._variance - @property def entropy(self): return self._entropy diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 9d5ca41fae3..267632c4fd9 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -46,7 +46,7 @@ def __init__(self, base_distribution, transforms, validate_args=None): transforms, ] elif isinstance(transforms, list): - raise ValueError("Mae a ComposeTransform first.") + raise ValueError("Make a ComposeTransform first.") else: raise ValueError( "transforms must be a Transform or list, but was {}".format(transforms) @@ -160,5 +160,4 @@ def safetanh(x, eps): # noqa: D103 def safeatanh(y, eps): # noqa: D103 lim = 1.0 - eps - y = y.clone() - return y.clamp(-lim, lim) + return y.clamp(-lim, lim).atanh() diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 8e5d0c2f9c9..01aa429a412 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -4,18 +4,23 @@ # LICENSE file in the root directory of this source tree. 
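# A quick numerical illustration of the safeatanh fix above: clamping to
# (-1 + eps, 1 - eps) before atanh keeps boundary values finite.
import torch

y = torch.tensor([-1.0, 0.5, 1.0])
eps = 1e-6
lim = 1.0 - eps
print(torch.atanh(y))                 # -inf and inf at the boundaries
print(y.clamp(-lim, lim).atanh())     # finite values, roughly +/-7.25 at the boundaries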
+from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior from .models import ( + Conv2dNet, + Conv3dNet, ConvNet, DdpgCnnActor, DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, DistributionalDQNnet, + DTActor, DuelingCnnDQNet, LSTMNet, MLP, + OnlineDTActor, ) -from .multiagent import MultiAgentMLP, QMixer, VDNMixer +from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py new file mode 100644 index 00000000000..8eb72f1f9ea --- /dev/null +++ b/torchrl/modules/models/decision_transformer.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import dataclasses + +import importlib +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +_has_transformers = importlib.util.find_spec("transformers") is not None + + +class DecisionTransformer(nn.Module): + """Online Decion Transformer. + + Desdescribed in https://arxiv.org/abs/2202.05607 . + + The transformer utilizes a default config to create the GPT2 model if the user does not provide a specific config. + default_config = { + ... "n_embd": 256, + ... "n_layer": 4, + ... "n_head": 4, + ... "n_inner": 1024, + ... "activation": "relu", + ... "n_positions": 1024, + ... "resid_pdrop": 0.1, + ... "attn_pdrop": 0.1, + } + + Args: + state_dim (int): dimension of the state space + action_dim (int): dimension of the action space + config (:obj:`~.DTConfig` or dict, optional): transformer architecture configuration, + used to create the GPT2Config from transformers. + Defaults to :obj:`~.default_config`. + + + Example: + >>> config = DecisionTransformer.default_config() + >>> config.n_embd = 128 + >>> print(config) + DTConfig(n_embd: 128, n_layer: 4, n_head: 4, n_inner: 1024, activation: relu, n_positions: 1024, resid_pdrop: 0.1, attn_pdrop: 0.1) + >>> # alternatively + >>> config = DecisionTransformer.DTConfig(n_embd=128) + >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config) + >>> batch_size = [3, 32] + >>> length = 10 + >>> observation = torch.randn(*batch_size, length, 4) + >>> action = torch.randn(*batch_size, length, 2) + >>> return_to_go = torch.randn(*batch_size, length, 1) + >>> output = model(observation, action, return_to_go) + >>> output.shape + torch.Size([3, 32, 10, 128]) + + """ + + @dataclass + class DTConfig: + """Default configuration for DecisionTransformer.""" + + n_embd: Any = 256 + n_layer: Any = 4 + n_head: Any = 4 + n_inner: Any = 1024 + activation: Any = "relu" + n_positions: Any = 1024 + resid_pdrop: Any = 0.1 + attn_pdrop: Any = 0.1 + + def __repr__(self): + fields = [] + for f in dataclasses.fields(self): + value = getattr(self, f.name) + fields.append(f"{f.name}: {value}") + fields = ", ".join(fields) + return f"{self.__class__.__name__}({fields})" + + @classmethod + def default_config(cls): + return cls.DTConfig() + + def __init__( + self, + state_dim, + action_dim, + config: dict | DTConfig = None, + ): + if not _has_transformers: + raise ImportError( + "transformers is not installed. Please install it with `pip install transformers`." 
+ ) + import transformers + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + if config is None: + config = self.default_config() + if isinstance(config, self.DTConfig): + config = dataclasses.asdict(config) + if not isinstance(config, dict): + try: + config = dict(config) + except Exception as err: + raise TypeError( + f"Config of type {type(config)} is not supported." + ) from err + + super(DecisionTransformer, self).__init__() + + gpt_config = transformers.GPT2Config( + n_embd=config["n_embd"], + n_layer=config["n_layer"], + n_head=config["n_head"], + n_inner=config["n_inner"], + activation_function=config["activation"], + n_positions=config["n_positions"], + resid_pdrop=config["resid_pdrop"], + attn_pdrop=config["attn_pdrop"], + vocab_size=1, + ) + self.state_dim = state_dim + self.action_dim = action_dim + self.hidden_size = config["n_embd"] + + self.transformer = GPT2Model(config=gpt_config) + + self.embed_return = torch.nn.Linear(1, self.hidden_size) + self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) + self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size) + + self.embed_ln = nn.LayerNorm(self.hidden_size) + + def forward( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + ): + batch_size, seq_length = observation.shape[:-2], observation.shape[-2] + batch_size_orig = batch_size + if len(batch_size) != 1: + # TODO: vmap over transformer once this is possible + observation = observation.view(-1, *observation.shape[-2:]) + action = action.view(-1, *action.shape[-2:]) + return_to_go = return_to_go.view(-1, *return_to_go.shape[-2:]) + batch_size = torch.Size([batch_size.numel()]) + + # embed each modality with a different head + state_embeddings = self.embed_state(observation) + action_embeddings = self.embed_action(action) + returns_embeddings = self.embed_return(return_to_go) + + # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) + # which works nice in an autoregressive sense since states predict actions + stacked_inputs = ( + torch.stack( + (returns_embeddings, state_embeddings, action_embeddings), dim=-3 + ) + .permute(*range(len(batch_size)), -2, -3, -1) + .reshape(*batch_size, 3 * seq_length, self.hidden_size) + ) + stacked_inputs = self.embed_ln(stacked_inputs) + + # we feed in the input embeddings (not word indices as in NLP) to the model + transformer_outputs = self.transformer( + inputs_embeds=stacked_inputs, + ) + x = transformer_outputs["last_hidden_state"] + + # reshape x so that the second dimension corresponds to the original + # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t + x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).permute( + *range(len(batch_size)), -2, -3, -1 + ) + if batch_size_orig is batch_size: + return x[..., 1, :, :] # only state tokens + return x[..., 1, :, :].view(*batch_size_orig, *x.shape[-2:]) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index a72af43aa13..1cc10316045 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -2,6 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
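# A shape-only sketch of the (R_t, s_t, a_t) interleaving performed in the
# DecisionTransformer forward pass above, with hypothetical sizes
# (batch=2, seq=5, hidden=8) and no actual transformer call:
import torch

B, T, H = 2, 5, 8
returns_emb = torch.randn(B, T, H)
state_emb = torch.randn(B, T, H)
action_emb = torch.randn(B, T, H)

stacked = (
    torch.stack((returns_emb, state_emb, action_emb), dim=-3)  # [B, 3, T, H]
    .permute(0, 2, 1, 3)                                        # [B, T, 3, H]
    .reshape(B, 3 * T, H)                                       # (R_1, s_1, a_1, R_2, ...)
)
# after the transformer, tokens are regrouped and only state tokens are kept
x = stacked.reshape(B, T, 3, H).permute(0, 2, 1, 3)             # [B, 3, T, H]
state_tokens = x[..., 1, :, :]
assert torch.equal(state_tokens, state_emb)   # identity here, since the transformer is skipped
print(state_tokens.shape)                     # torch.Size([2, 5, 8])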
+from __future__ import annotations + +import dataclasses + import warnings from numbers import Number from typing import Dict, List, Optional, Sequence, Tuple, Type, Union @@ -13,12 +17,14 @@ from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING +from torchrl.modules.models.decision_transformer import DecisionTransformer from torchrl.modules.models.utils import ( _find_depth, create_on_device, LazyMapping, SquashDims, Squeeze2dLayer, + SqueezeLayer, ) @@ -203,7 +209,7 @@ def __init__( if not (len(self.num_cells) == depth or depth is None): raise RuntimeError( "depth and num_cells length conflict, \ - consider matching or specifying a constan num_cells argument together with a a desired depth" + consider matching or specifying a constant num_cells argument together with a a desired depth" ) layers = self._make_net(device) super().__init__(*layers) @@ -406,7 +412,7 @@ def __init__( if not (len(getattr(self, _field)) == _depth or _depth is None): raise RuntimeError( f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, " - + f"consider matching or specifying a constan {_field} argument together with a a desired depth" + + f"consider matching or specifying a constant {_field} argument together with a a desired depth" ) self.out_features = self.num_cells[-1] @@ -481,6 +487,230 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return out +Conv2dNet = ConvNet + + +class Conv3dNet(nn.Sequential): + """A 3D-convolutional neural network. + + Args: + in_features (int, optional): number of input features. A lazy implementation that automatically retrieves + the input size will be used if none is provided. + depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the + desired input size, and with an output size equal to the last element of the num_cells argument. + If no depth is indicated, the depth information should be contained in the num_cells argument (see below). + If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to + the depth. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers out_features will match the content of num_cells. + default: ``[32, 32, 32]`` or ``[32] * depth` is depth is not ``None``. + kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the + depth, defined by the num_cells or depth arguments. + strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the + depth, defined by the num_cells or depth arguments. + activation_class (Type[nn.Module]): activation class to be used. + default: nn.Tanh + activation_kwargs (dict, optional): kwargs to be used with the activation class; + norm_class (Type, optional): normalization class, if any; + norm_kwargs (dict, optional): kwargs to be used with the normalization layers; + bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. + default: True; + aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain. + default: SquashDims; + aggregator_kwargs (dict, optional): kwargs for the aggregator_class; + squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. + default: False. 
+ device (Optional[DEVICE_TYPING]): device to create the module on. + + Examples: + >>> # All of the following examples provide valid, working MLPs + >>> cnet = Conv3dNet(in_features=3, depth=1, num_cells=[32,]) + >>> print(cnet) + Conv3dNet( + (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (1): ELU(alpha=1.0) + (2): SquashDims() + ) + >>> cnet = Conv3dNet(in_features=3, depth=4, num_cells=32) + >>> print(cnet) + Conv3dNet( + (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (1): ELU(alpha=1.0) + (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (3): ELU(alpha=1.0) + (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (5): ELU(alpha=1.0) + (6): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg + >>> print(cnet) + Conv3dNet( + (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (1): ELU(alpha=1.0) + (2): Conv3d(32, 33, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (3): ELU(alpha=1.0) + (4): Conv3d(33, 34, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (5): ELU(alpha=1.0) + (6): Conv3d(34, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3, 4)]) # defines kernels, possibly rectangular + >>> print(cnet) + Conv3dNet( + (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) + (1): ELU(alpha=1.0) + (2): Conv3d(32, 33, kernel_size=(4, 4, 4), stride=(1, 1, 1)) + (3): ELU(alpha=1.0) + (4): Conv3d(33, 34, kernel_size=(5, 5, 5), stride=(1, 1, 1)) + (5): ELU(alpha=1.0) + (6): Conv3d(34, 35, kernel_size=(2, 3, 4), stride=(1, 1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + + """ + + def __init__( + self, + in_features: Optional[int] = None, + depth: Optional[int] = None, + num_cells: Union[Sequence, int] = None, + kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, + strides: Union[Sequence, int] = 1, + paddings: Union[Sequence, int] = 0, + activation_class: Type[nn.Module] = nn.ELU, + activation_kwargs: Optional[dict] = None, + norm_class: Optional[Type[nn.Module]] = None, + norm_kwargs: Optional[dict] = None, + bias_last_layer: bool = True, + aggregator_class: Optional[Type[nn.Module]] = SquashDims, + aggregator_kwargs: Optional[dict] = None, + squeeze_output: bool = False, + device: Optional[DEVICE_TYPING] = None, + ): + if num_cells is None: + if depth is None: + num_cells = [32, 32, 32] + else: + num_cells = [32] * depth + + self.in_features = in_features + self.activation_class = activation_class + self.activation_kwargs = ( + activation_kwargs if activation_kwargs is not None else {} + ) + self.norm_class = norm_class + self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + self.bias_last_layer = bias_last_layer + self.aggregator_class = aggregator_class + self.aggregator_kwargs = ( + aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 4} + ) + self.squeeze_output = squeeze_output + # self.single_bias_last_layer = single_bias_last_layer + + depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings) + self.depth = depth + if depth == 0: + raise ValueError("Null depth is not permitted with Conv3dNet.") + + for _field, _value in zip( + ["num_cells", "kernel_sizes", "strides", "paddings"], + [num_cells, kernel_sizes, strides, paddings], + ): + _depth = depth + setattr( + self, + _field, + 
(_value if isinstance(_value, Sequence) else [_value] * _depth), + ) + if not (len(getattr(self, _field)) == _depth or _depth is None): + raise ValueError( + f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, " + + f"consider matching or specifying a constant {_field} argument together with a desired depth" + ) + + self.out_features = self.num_cells[-1] + + self.depth = len(self.kernel_sizes) + layers = self._make_net(device) + super().__init__(*layers) + + def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: + layers = [] + in_features = [self.in_features] + self.num_cells[: self.depth] + out_features = self.num_cells + [self.out_features] + kernel_sizes = self.kernel_sizes + strides = self.strides + paddings = self.paddings + for i, (_in, _out, _kernel, _stride, _padding) in enumerate( + zip(in_features, out_features, kernel_sizes, strides, paddings) + ): + _bias = (i < len(in_features) - 1) or self.bias_last_layer + if _in is not None: + layers.append( + nn.Conv3d( + _in, + _out, + kernel_size=_kernel, + stride=_stride, + bias=_bias, + padding=_padding, + device=device, + ) + ) + else: + layers.append( + nn.LazyConv3d( + _out, + kernel_size=_kernel, + stride=_stride, + bias=_bias, + padding=_padding, + device=device, + ) + ) + + layers.append( + create_on_device( + self.activation_class, device, **self.activation_kwargs + ) + ) + if self.norm_class is not None: + layers.append( + create_on_device(self.norm_class, device, **self.norm_kwargs) + ) + + if self.aggregator_class is not None: + layers.append( + create_on_device( + self.aggregator_class, device, **self.aggregator_kwargs + ) + ) + + if self.squeeze_output: + layers.append(SqueezeLayer((-3, -2, -1))) + return layers + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + try: + *batch, C, D, L, W = inputs.shape + except ValueError as err: + raise ValueError( + f"The input value of {self.__class__.__name__} must have at least 4 dimensions, got {inputs.ndim} instead." + ) from err + if len(batch) > 1: + inputs = inputs.flatten(0, len(batch) - 1) + out = super().forward(inputs) + if len(batch) > 1: + out = out.unflatten(0, batch) + return out + + class DuelingMlpDQNet(nn.Module): """Creates a Dueling MLP Q-network. @@ -1135,6 +1365,184 @@ def forward( hidden0_in: Optional[torch.Tensor] = None, hidden1_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +class OnlineDTActor(nn.Module): + """Online Decision Transformer Actor class. + + Actor class for the Online Decision Transformer to sample actions from a Gaussian distribution, as presented in `"Online Decision Transformer" `. + Returns mu and sigma for the Gaussian distribution to sample actions from. + + Args: + state_dim (int): state dimension. + action_dim (int): action dimension. + transformer_config (Dict or :class:`DecisionTransformer.DTConfig`): + config for the GPT2 transformer. + Defaults to :meth:`~.default_config`. + device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + + Examples: + >>> model = OnlineDTActor(state_dim=4, action_dim=2, + ... 
transformer_config=OnlineDTActor.default_config()) + >>> observation = torch.randn(32, 10, 4) + >>> action = torch.randn(32, 10, 2) + >>> return_to_go = torch.randn(32, 10, 1) + >>> mu, std = model(observation, action, return_to_go) + >>> mu.shape + torch.Size([32, 10, 2]) + >>> std.shape + torch.Size([32, 10, 2]) + """ + + def __init__( + self, + state_dim: int, + action_dim: int, + transformer_config: Dict | DecisionTransformer.DTConfig = None, + device: Optional[DEVICE_TYPING] = None, + ): + super().__init__() + if transformer_config is None: + transformer_config = self.default_config() + if isinstance(transformer_config, DecisionTransformer.DTConfig): + transformer_config = dataclasses.asdict(transformer_config) + self.transformer = DecisionTransformer( + state_dim=state_dim, + action_dim=action_dim, + config=transformer_config, + ) + self.action_layer_mean = nn.Linear( + transformer_config["n_embd"], action_dim, device=device + ) + self.action_layer_logstd = nn.Linear( + transformer_config["n_embd"], action_dim, device=device + ) + + self.log_std_min, self.log_std_max = -5.0, 2.0 + + def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + + self.action_layer_mean.apply(weight_init) + self.action_layer_logstd.apply(weight_init) + + def forward( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + hidden_state = self.transformer(observation, action, return_to_go) + mu = self.action_layer_mean(hidden_state) + log_std = self.action_layer_logstd(hidden_state) + + log_std = torch.tanh(log_std) + # log_std is the output of tanh so it will be between [-1, 1] + # map it to be between [log_std_min, log_std_max] + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * ( + log_std + 1.0 + ) + std = log_std.exp() + + return mu, std + + @classmethod + def default_config(cls): + """Default configuration for :class:`~.OnlineDTActor`.""" + return DecisionTransformer.DTConfig( + n_embd=512, + n_layer=4, + n_head=4, + n_inner=2048, + activation="relu", + n_positions=1024, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) + + +class DTActor(nn.Module): + """Decision Transformer Actor class. + + Actor class for the Decision Transformer to output deterministic action as presented in `"Decision Transformer" `. + Returns the deterministic actions. + + Args: + state_dim (int): state dimension. + action_dim (int): action dimension. + transformer_config (Dict or :class:`DecisionTransformer.DTConfig`, optional): + config for the GPT2 transformer. + Defaults to :meth:`~.default_config`. + device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + + Examples: + >>> model = DTActor(state_dim=4, action_dim=2, + ... 
transformer_config=DTActor.default_config()) + >>> observation = torch.randn(32, 10, 4) + >>> action = torch.randn(32, 10, 2) + >>> return_to_go = torch.randn(32, 10, 1) + >>> output = model(observation, action, return_to_go) + >>> output.shape + torch.Size([32, 10, 2]) + + """ + + def __init__( + self, + state_dim: int, + action_dim: int, + transformer_config: Dict | DecisionTransformer.DTConfig = None, + device: Optional[DEVICE_TYPING] = None, + ): + super().__init__() + if transformer_config is None: + transformer_config = self.default_config() + if isinstance(transformer_config, DecisionTransformer.DTConfig): + transformer_config = dataclasses.asdict(transformer_config) + self.transformer = DecisionTransformer( + state_dim=state_dim, + action_dim=action_dim, + config=transformer_config, + ) + self.action_layer = nn.Linear( + transformer_config["n_embd"], action_dim, device=device + ) + + def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + + self.action_layer.apply(weight_init) + + def forward( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + ) -> torch.Tensor: + hidden_state = self.transformer(observation, action, return_to_go) + out = self.action_layer(hidden_state) + return out + + @classmethod + def default_config(cls): + """Default configuration for :class:`~.DTActor`.""" + return DecisionTransformer.DTConfig( + n_embd=512, + n_layer=4, + n_head=4, + n_inner=2048, + activation="relu", + n_positions=1024, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index de565b336d2..8ebd97fdaa8 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -12,7 +12,7 @@ from ...data import DEVICE_TYPING -from .models import MLP +from .models import ConvNet, MLP class MultiAgentMLP(nn.Module): @@ -215,10 +215,10 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: if self.centralised: # If the parameters are shared, and it is centralised, all agents will have the same output # We expand it to maintain the agent dimension, but values will be the same for all agents - output = ( - output.view(*output.shape[:-1], self.n_agent_outputs) - .unsqueeze(-2) - .expand(*output.shape[:-1], self.n_agents, self.n_agent_outputs) + output = output.view(*output.shape[:-1], self.n_agent_outputs) + output = output.unsqueeze(-2) + output = output.expand( + *output.shape[:-2], self.n_agents, self.n_agent_outputs ) if output.shape[-2:] != (self.n_agents, self.n_agent_outputs): @@ -230,6 +230,228 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: return output +class MultiAgentConvNet(nn.Module): + """Multi-agent CNN. + + In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may take the entire observation space (across agents) or on a per-agent basis to compute its output, which we refer to as "centralized" and "non-centralized", respectively. + + It expects inputs with shape ``(*B, n_agents, channels, x, y)``. + + Args: + n_agents (int): number of agents. + centralised (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its data as input. 
+ share_params (bool): If ``True``, the same :class:`~torchrl.modules.ConvNet` will be used to make the forward pass + for all agents (homogeneous policies). Otherwise, each agent will use a different :class:`~torchrl.modules.ConvNet` to process + its input (heterogeneous policies). + device (str or torch.device, optional): device to create the module on. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers ``out_features`` will match the content of ``num_cells``. + kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the convolutional network. + Defaults to ``5``. + strides (int or Sequence[int]): Stride(s) of the convolutional network. If iterable, the length must match the + depth, defined by the num_cells or depth arguments. + Defaults to ``2``. + activation_class (Type[nn.Module]): activation class to be used. + Default to :class:`torch.nn.ELU`. + **kwargs: for :class:`~torchrl.modules.models.ConvNet` can be passed to customize the ConvNet. + + + Examples: + >>> import torch + >>> from torchrl.modules import MultiAgentConvNet + >>> batch = (3,2) + >>> n_agents = 7 + >>> channels, x, y = 3, 100, 100 + >>> obs = torch.randn(*batch, n_agents, channels, x, y) + >>> # First lets consider a centralised network with shared parameters. + >>> cnn = MultiAgentConvNet( + ... n_agents, + ... centralised = True, + ... share_params = True + ... ) + >>> print(cnn) + MultiAgentConvNet( + (agent_networks): ModuleList( + (0): ConvNet( + (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + ) + ) + >>> result = cnn(obs) + >>> # The final dimension of the resulting tensor would be determined based on the layer definition arguments and the shape of input 'obs'. + >>> print(result.shape) + torch.Size([3, 2, 7, 2592]) + >>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (eg. for a value function) + >>> print(all(result[0,0,0] == result[0,0,1])) + True + + >>> # Alternatively, a local network with parameter sharing (eg. decentralised weight sharing policy) + >>> cnn = MultiAgentConvNet( + ... n_agents, + ... centralised = False, + ... share_params = True + ... ) + >>> print(cnn) + MultiAgentConvNet( + (agent_networks): ModuleList( + (0): ConvNet( + (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + ) + ) + >>> print(result.shape) + torch.Size([3, 2, 7, 2592]) + >>> # Parameters are shared but not observations, hence each agent has a different output. + >>> print(all(result[0,0,0] == result[0,0,1])) + False + + >>> # Or multiple local networks identical in structure but with differing weights. + >>> cnn = MultiAgentConvNet( + ... n_agents, + ... centralised = False, + ... share_params = False + ... 
) + >>> print(cnn) + MultiAgentConvNet( + (agent_networks): ModuleList( + (0-6): 7 x ConvNet( + (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + ) + ) + >>> print(result.shape) + torch.Size([3, 2, 7, 2592]) + >>> print(all(result[0,0,0] == result[0,0,1])) + False + + >>> # Or where inputs are shared but not parameters. + >>> cnn = MultiAgentConvNet( + ... n_agents, + ... centralised = True, + ... share_params = False + ... ) + >>> print(cnn) + MultiAgentConvNet( + (agent_networks): ModuleList( + (0-6): 7 x ConvNet( + (0): Conv2d(28, 32, kernel_size=(5, 5), stride=(2, 2)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + ) + ) + >>> print(result.shape) + torch.Size([3, 2, 7, 2592]) + >>> print(all(result[0,0,0] == result[0,0,1])) + False + """ + + def __init__( + self, + n_agents: int, + centralised: bool, + share_params: bool, + device: Optional[DEVICE_TYPING] = None, + num_cells: Optional[Sequence[int]] = None, + kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, + strides: Union[Sequence, int] = 2, + paddings: Union[Sequence, int] = 0, + activation_class: Type[nn.Module] = nn.ELU, + **kwargs, + ): + super().__init__() + + self.n_agents = n_agents + self.centralised = centralised + self.share_params = share_params + + self.agent_networks = nn.ModuleList( + [ + ConvNet( + num_cells=num_cells, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + activation_class=activation_class, + device=device, + **kwargs, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + def forward(self, inputs: torch.Tensor): + if len(inputs.shape) < 4: + raise ValueError( + """Multi-agent network expects (*batch_size, agent_index, x, y, channels)""" + ) + if inputs.shape[-4] != self.n_agents: + raise ValueError( + f"""Multi-agent network expects {self.n_agents} but got {inputs.shape[-4]}""" + ) + # If the model is centralized, agents have full observability + if self.centralised: + shape = ( + *inputs.shape[:-4], + self.n_agents * inputs.shape[-3], + inputs.shape[-2], + inputs.shape[-1], + ) + inputs = torch.reshape(inputs, shape) + + # If the parameters are not shared, each agent has its own network + if not self.share_params: + if self.centralised: + output = torch.stack( + [net(inputs) for net in self.agent_networks], dim=-2 + ) + else: + output = torch.stack( + [ + net(inp) + for i, (net, inp) in enumerate( + zip(self.agent_networks, inputs.unbind(-4)) + ) + ], + dim=-2, + ) + else: + output = self.agent_networks[0](inputs) + if self.centralised: + # If the parameters are shared, and it is centralised all agents will have the same output. + # We expand it to maintain the agent dimension, but values will be the same for all agents + n_agent_outputs = output.shape[-1] + output = output.view(*output.shape[:-1], n_agent_outputs) + output = output.unsqueeze(-2) + output = output.expand( + *output.shape[:-2], self.n_agents, n_agent_outputs + ) + return output + + class Mixer(nn.Module): """A multi-agent value mixer. 
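A minimal sketch of the channel folding that ``MultiAgentConvNet.forward`` applies when ``centralised=True`` (shapes follow the docstring example above; illustrative only):

    >>> import torch
    >>> batch, n_agents, channels, x, y = (3, 2), 7, 3, 100, 100
    >>> obs = torch.randn(*batch, n_agents, channels, x, y)
    >>> # per-agent channels are folded into one channel dimension before the ConvNet(s) run
    >>> centralised_in = obs.reshape(*obs.shape[:-4], n_agents * channels, x, y)
    >>> print(centralised_in.shape)
    torch.Size([3, 2, 21, 100, 100])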
diff --git a/torchrl/modules/models/rlhf.py b/torchrl/modules/models/rlhf.py index 066be2a5ad6..48953e43a4a 100644 --- a/torchrl/modules/models/rlhf.py +++ b/torchrl/modules/models/rlhf.py @@ -41,7 +41,7 @@ def __init__(self, model_path=None): from transformers import GPT2LMHeadModel, GPT2TokenizerFast super().__init__() - if model_path: + if model_path is not None: model = GPT2LMHeadModel.from_pretrained(model_path, return_dict=False) else: model = GPT2LMHeadModel(GPT2LMHeadModel.config_class()) @@ -75,6 +75,7 @@ def _compute_end_scores(self, rewards, input_ids): return torch.stack(end_scores) + # TODO: move to objectives @staticmethod def compute_reward_loss(chosen_batch, rejected_batch, pad_token_id=50256): """Compute the reward loss given a chosen and rejected batch. diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 3e8515ab7aa..b4fa7eb58fd 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -96,7 +96,7 @@ def _find_depth(depth: Optional[int], *list_or_ints: Sequence): if isinstance(item, (list, tuple)): depth = len(item) if depth is None: - raise Exception( + raise ValueError( f"depth=None requires one of the input args (kernel_sizes, strides, " f"num_cells) to be a a list or tuple. Got {tuple(type(item) for item in list_or_ints)}" ) diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index a56f383e929..3cf183b5e51 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -6,7 +6,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.envs import EnvBase +from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 057efa1ef3a..66fd1bb9e1f 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -8,7 +8,7 @@ import torch from tensordict.tensordict import TensorDictBase -from torchrl.envs import EnvBase +from torchrl.envs.common import EnvBase from torchrl.modules import SafeModule diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index 357cfb49bea..e41f98a2852 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -7,7 +7,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn -from torchrl.envs import EnvBase +from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 98db8c87663..7605238f99a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -8,6 +8,7 @@ ActorCriticOperator, ActorCriticWrapper, ActorValueOperator, + DecisionTransformerInferenceWrapper, DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, @@ -22,6 +23,7 @@ from .common import SafeModule, VmapModule from .exploration import ( AdditiveGaussianWrapper, + EGreedyModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) @@ -29,6 +31,6 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import LSTMModule +from .rnn import GRUModule, LSTMModule from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 
59688b68feb..4defee3965a 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -7,7 +7,7 @@ import torch -from tensordict import TensorDictBase +from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( dispatch, TensorDictModule, @@ -20,6 +20,7 @@ from torch.distributions import Categorical from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.utils import _process_action_space_spec from torchrl.modules.models.models import DistributionalDQNnet from torchrl.modules.tensordict_module.common import SafeModule from torchrl.modules.tensordict_module.probabilistic import ( @@ -27,7 +28,6 @@ SafeProbabilisticTensorDictSequential, ) from torchrl.modules.tensordict_module.sequence import SafeSequential -from torchrl.modules.utils.utils import _find_action_space class Actor(SafeModule): @@ -188,7 +188,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), - ... minimum=-1, maximum=1) + ... low=-1, high=1) >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( @@ -299,7 +299,6 @@ def __init__( in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, ) -> None: - if in_keys is None: in_keys = ["observation"] if out_keys is None: @@ -328,6 +327,8 @@ class QValueModule(TensorDictModuleBase): conditions the action_space. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions, action values and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``. @@ -379,6 +380,7 @@ def __init__( self, action_space: Optional[str], action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: Optional[TensorSpec] = None, @@ -408,7 +410,11 @@ def __init__( ) if action_value_key is None: action_value_key = "action_value" - self.in_keys = [action_value_key] + self.action_mask_key = action_mask_key + in_keys = [action_value_key] + if self.action_mask_key is not None: + in_keys.append(self.action_mask_key) + self.in_keys = in_keys if out_keys is None: out_keys = ["action", action_value_key, "chosen_action_value"] elif action_value_key not in out_keys: @@ -447,6 +453,15 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." 
+ ) + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) @@ -468,11 +483,17 @@ def _one_hot(value: torch.Tensor) -> torch.Tensor: def _categorical(value: torch.Tensor) -> torch.Tensor: return torch.argmax(value, dim=-1).to(torch.long) - def _mult_one_hot(self, value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + def _mult_one_hot( + self, value: torch.Tensor, support: torch.Tensor = None + ) -> torch.Tensor: + if self.var_nums is None: + raise ValueError( + "var_nums must be provided to the constructor for multi one-hot action spaces." + ) values = value.split(self.var_nums, dim=-1) return torch.cat( [ - QValueHook._one_hot( + self._one_hot( _value, ) for _value in values @@ -523,6 +544,8 @@ class DistributionalQValueModule(QValueModule): support (torch.Tensor): support of the action values. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions and action values. Defaults to ``["action", "action_value"]``. @@ -578,6 +601,7 @@ def __init__( action_space: Optional[str], support: torch.Tensor, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: TensorSpec = None, @@ -590,6 +614,7 @@ def __init__( super().__init__( action_space=action_space, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, var_nums=var_nums, spec=spec, @@ -604,6 +629,15 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) @@ -676,59 +710,6 @@ def _binary(self, value: torch.Tensor) -> torch.Tensor: ) -def _process_action_space_spec(action_space, spec): - original_spec = spec - composite_spec = False - if isinstance(spec, CompositeSpec): - # this will break whenever our action is more complex than a single tensor - try: - if "action" in spec.keys(): - _key = "action" - else: - # the first key is the action - for _key in spec.keys(True, True): - if isinstance(_key, tuple) and _key[-1] == "action": - break - else: - raise KeyError - spec = spec[_key] - composite_spec = True - except KeyError: - raise KeyError( - "action could not be found in the spec. Make sure " - "you pass a spec that is either a native action spec or a composite action spec " - "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." 
- ) - if action_space is not None: - if isinstance(action_space, CompositeSpec): - raise ValueError("action_space cannot be of type CompositeSpec.") - if ( - spec is not None - and isinstance(action_space, TensorSpec) - and action_space is not spec - ): - raise ValueError( - "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." - ) - if isinstance(action_space, TensorSpec): - spec = action_space - action_space = _find_action_space(action_space) - # check that the spec and action_space match - if spec is not None and _find_action_space(spec) != action_space: - raise ValueError( - f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." - ) - elif spec is not None: - action_space = _find_action_space(spec) - else: - raise ValueError( - "Neither action_space nor spec was defined. The action space cannot be inferred." - ) - if composite_spec: - spec = original_spec - return action_space, spec - - class QValueHook: """Q-Value hook for Q-value policies. @@ -746,6 +727,8 @@ class QValueHook: action_value_key (str or tuple of str, optional): to be used when hooked on a TensorDictModule. The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): to be used when hooked on a TensorDictModule. The output keys representing the actions, action values and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``. @@ -781,6 +764,7 @@ def __init__( action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -795,6 +779,7 @@ def __init__( action_space=action_space, var_nums=var_nums, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -824,6 +809,11 @@ class DistributionalQValueHook(QValueHook): Args: action_space (str): Action space. Must be one of ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``. + action_value_key (str or tuple of str, optional): to be used when hooked on + a TensorDictModule. The input key representing the action value. Defaults + to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). support (torch.Tensor): support of the action values. var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each @@ -871,6 +861,7 @@ def __init__( support: torch.Tensor, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -885,6 +876,7 @@ def __init__( var_nums=var_nums, support=support, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -932,6 +924,8 @@ class QValueActor(SafeSequential): is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must match one of its output keys. Otherwise, this string represents the name of the action-value entry in the output tensordict. 
+ action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). .. note:: ``out_keys`` cannot be passed. If the module is a :class:`tensordict.nn.TensorDictModule` @@ -990,6 +984,7 @@ def __init__( safe=False, action_space: Optional[str] = None, action_value_key=None, + action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): warnings.warn( @@ -1035,6 +1030,7 @@ def __init__( spec=spec, safe=safe, action_space=action_space, + action_mask_key=action_mask_key, ) super().__init__(module, qvalue) @@ -1083,6 +1079,12 @@ class DistributionalQValueActor(QValueActor): make_log_softmax (bool, optional): if ``True`` and if the module is not of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax operation will be applied along dimension -2 of the action value tensor. + action_value_key (str or tuple of str, optional): if the input module + is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must + match one of its output keys. Otherwise, this string represents + the name of the action-value entry in the output tensordict. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). Examples: >>> import torch @@ -1127,6 +1129,7 @@ def __init__( var_nums: Optional[int] = None, action_space: Optional[str] = None, action_value_key: str = "action_value", + action_mask_key: Optional[NestedKey] = None, make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): @@ -1169,6 +1172,7 @@ def __init__( spec=spec, safe=safe, action_space=action_space, + action_mask_key=action_mask_key, support=support, var_nums=var_nums, ) @@ -1612,6 +1616,218 @@ def get_value_operator(self) -> SafeSequential: get_value_head = get_value_operator +class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): + """Inference Action Wrapper for the Decision Transformer. + + A wrapper specifically designed for the Decision Transformer, which will mask the + input tensordict sequences to the inference context. + The output will be a TensorDict with the same keys as the input, but with only the last + action of the predicted action sequence and the last return to go. + + This module returns a modified copy of the tensordict, i.e. it does + **not** modify the tensordict in-place. + + .. note:: If the action, observation or reward-to-go key is not standard, + the method :meth:`~.set_tensor_keys` should be used, e.g. + + >>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz") + + The in_keys are the observation, action and return-to-go keys. The out-keys + match the in-keys, with the addition of any other out-key from the policy + (e.g., parameters of the distribution or hidden values). + + Args: + policy (TensorDictModule): The policy module that takes in + observations and produces an action value. + + Keyword Args: + inference_context (int): The number of previous actions that will not be masked in the context. + For example, for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries + of the context will be masked. Defaults to 5. + spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module.
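+ + .. note:: Only the last ``inference_context`` steps of the observation and return-to-go sequences are left visible to the policy (earlier entries are zeroed out and the action sequence is shifted by one step, see :meth:`~.mask_context`); only the prediction for the final step is then written back to the output tensordict.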
+ + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules import ( + ... ProbabilisticActor, + ... TanhDelta, + ... DTActor, + ... DecisionTransformerInferenceWrapper, + ... ) + >>> dtactor = DTActor(state_dim=4, action_dim=2, + ... transformer_config=DTActor.default_config() + ... ) + >>> actor_module = TensorDictModule( + ... dtactor, + ... in_keys=["observation", "action", "return_to_go"], + ... out_keys=["param"]) + >>> dist_class = TanhDelta + >>> dist_kwargs = { + ... "min": -1.0, + ... "max": 1.0, + ... } + >>> actor = ProbabilisticActor( + ... in_keys=["param"], + ... out_keys=["action"], + ... module=actor_module, + ... distribution_class=dist_class, + ... distribution_kwargs=dist_kwargs) + >>> inference_actor = DecisionTransformerInferenceWrapper(actor) + >>> sequence_length = 20 + >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), + ... "action": torch.randn(1, sequence_length, 2), + ... "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) + >>> result = inference_actor(td) + >>> print(result) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), + param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False), + return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([1]), + device=None, + is_shared=False) + """ + + def __init__( + self, + policy: TensorDictModule, + *, + inference_context: int = 5, + spec: Optional[TensorSpec] = None, + ): + super().__init__(policy) + self.observation_key = "observation" + self.action_key = "action" + self.return_to_go_key = "return_to_go" + self.inference_context = inference_context + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({self.action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + elif hasattr(self.td_module, "_spec"): + self._spec = self.td_module._spec.clone() + if self.action_key not in self._spec.keys(): + self._spec[self.action_key] = None + elif hasattr(self.td_module, "spec"): + self._spec = self.td_module.spec.clone() + if self.action_key not in self._spec.keys(): + self._spec[self.action_key] = None + else: + self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self.checked = False + + @property + def in_keys(self): + return [self.observation_key, self.action_key, self.return_to_go_key] + + @property + def out_keys(self): + return sorted( + set(self.td_module.out_keys).union( + {self.observation_key, self.action_key, self.return_to_go_key} + ), + key=str, + ) + + def set_tensor_keys(self, **kwargs): + """Sets the input keys of the module. + + Keyword Args: + observation (NestedKey, optional): The observation key. + action (NestedKey, optional): The action key. + return_to_go (NestedKey, optional): The return_to_go key. + + """ + observation_key = unravel_key(kwargs.pop("observation", self.observation_key)) + action_key = unravel_key(kwargs.pop("action", self.action_key)) + return_to_go_key = unravel_key( + kwargs.pop("return_to_go", self.return_to_go_key) + ) + if kwargs: + raise TypeError( + f"Got unknown input(s) {kwargs.keys()}. Accepted keys are 'action', 'return_to_go' and 'observation'." 
+ ) + if action_key not in self.td_module.out_keys: + raise ValueError( + f"The action key {action_key} was not found in the policy out_keys {self.td_module.out_keys}." + ) + self.observation_key = observation_key + self.action_key = action_key + self.return_to_go_key = return_to_go_key + + def step(self, frames: int = 1) -> None: + pass + + @staticmethod + def _check_tensor_dims(reward, obs, action): + if not (reward.shape[:-1] == obs.shape[:-1] == action.shape[:-1]): + raise ValueError( + "Mismatched tensor dimensions. This is not supported yet, file an issue on torchrl" + ) + + def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: + """Mask the context of the input sequences.""" + observation = tensordict.get(self.observation_key).clone() + action = tensordict.get(self.action_key).clone() + return_to_go = tensordict.get(self.return_to_go_key).clone() + self._check_tensor_dims(return_to_go, observation, action) + + observation[..., : -self.inference_context, :] = 0 + action[ + ..., : -(self.inference_context - 1), : + ] = 0 # as we add zeros to the end of the action + action = torch.cat( + [ + action[..., 1:, :], + torch.zeros( + *action.shape[:-2], 1, action.shape[-1], device=action.device + ), + ], + dim=-2, + ) + return_to_go[..., : -self.inference_context, :] = 0 + + tensordict.set(self.observation_key, observation) + tensordict.set(self.action_key, action) + tensordict.set(self.return_to_go_key, return_to_go) + return tensordict + + def check_keys(self): + # an exception will be raised if the action key mismatch + self.set_tensor_keys() + self.checked = True + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self.checked: + self.check_keys() + """Forward pass of the inference wrapper.""" + tensordict = tensordict.clone(False) + obs = tensordict.get(self.observation_key) + # Mask the context of the input sequences + tensordict = self.mask_context(tensordict) + # forward pass + tensordict = self.td_module.forward(tensordict) + # get last action predicton + out_action = tensordict.get(self.action_key) + if tensordict.ndim == out_action.ndim - 1: + # then time dimension is in the TD's dimensions, and we must get rid of it + tensordict.batch_size = tensordict.batch_size[:-1] + out_action = out_action[..., -1, :] + tensordict.set(self.action_key, out_action) + # out_rtg = tensordict.get(self.return_to_go_key)[:, -1] + out_rtg = tensordict.get(self.return_to_go_key) + out_rtg = out_rtg[..., -1, :] + tensordict.set(self.return_to_go_key, out_rtg) + # set unmasked observation + tensordict.set(self.observation_key, obs) + return tensordict + + class TanhModule(TensorDictModuleBase): """A Tanh module for deterministic policies with bounded action space. @@ -1740,22 +1956,22 @@ def _make_low_high(self, low, high, leaf_spec): if low is None and leaf_spec is None: low = -torch.ones(()) elif low is None: - low = leaf_spec.space.minimum + low = leaf_spec.space.low elif leaf_spec is not None: - if (low != leaf_spec.space.minimum).any(): + if (low != leaf_spec.space.low).any(): raise ValueError( - f"The minimum value ({low}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.minimum})." + f"The minimum value ({low}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.low})." 
) if not isinstance(low, torch.Tensor): low = torch.tensor(low) if high is None and leaf_spec is None: high = torch.ones(()) elif high is None: - high = leaf_spec.space.maximum + high = leaf_spec.space.high elif leaf_spec is not None: - if (high != leaf_spec.space.maximum).any(): + if (high != leaf_spec.space.high).any(): raise ValueError( - f"The maximum value ({high}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.maximum})." + f"The maximum value ({high}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.high})." ) if not isinstance(high, torch.Tensor): high = torch.tensor(high) @@ -1822,4 +2038,4 @@ def __init__(self, base_model): value_head, in_keys=["x"], out_keys=["state_value"] ) - return super().__init__(common, actor_head, value_head) + super().__init__(common, actor_head, value_head) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 20d26b7aabd..d2e8ed8e3a1 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -7,7 +7,12 @@ import numpy as np import torch -from tensordict.nn import TensorDictModule, TensorDictModuleWrapper + +from tensordict.nn import ( + TensorDictModule, + TensorDictModuleBase, + TensorDictModuleWrapper, +) from tensordict.tensordict import TensorDictBase from tensordict.utils import expand_as_right, expand_right, NestedKey @@ -17,13 +22,168 @@ __all__ = [ "EGreedyWrapper", + "EGreedyModule", "AdditiveGaussianWrapper", "OrnsteinUhlenbeckProcessWrapper", ] + +class EGreedyModule(TensorDictModuleBase): + """Epsilon-Greedy exploration module. + + This module randomly updates the action(s) in a tensordict given an epsilon-greedy exploration strategy. + At each call, random draws (one per action) are executed given a certain probability threshold. If successful, + the corresponding actions are replaced by random samples drawn from the action spec provided. + Others are left unchanged. + + Args: + spec (TensorSpec): the spec used for sampling actions. + eps_init (scalar, optional): initial epsilon value. + default: 1.0 + eps_end (scalar, optional): final epsilon value. + default: 0.1 + annealing_num_steps (int, optional): number of steps it will take for epsilon to reach + the ``eps_end`` value. Defaults to ``1000``. + + Keyword Args: + action_key (NestedKey, optional): the key where the action can be found in the input tensordict. + Default is ``"action"``. + action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. + Default is ``None`` (corresponding to no mask). + + .. note:: + It is crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission, no warning or exception + will be raised if it is omitted!
+ + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictSequential + >>> from torchrl.modules import EGreedyModule, Actor + >>> from torchrl.data import BoundedTensorSpec + >>> torch.manual_seed(0) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(spec=spec, module=module) + >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2, spec=spec)) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td).get("action")) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.9055, -0.9277, -0.6295, -0.2532], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + + """ + + def __init__( + self, + spec: TensorSpec, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + *, + action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, + ): + self.action_key = action_key + self.action_mask_key = action_mask_key + in_keys = [self.action_key] + if self.action_mask_key is not None: + in_keys.append(self.action_mask_key) + self.in_keys = in_keys + self.out_keys = [self.action_key] + + super().__init__() + + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise RuntimeError("eps should decrease over time or be constant") + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """A step of epsilon decay. + + After ``self.annealing_num_steps`` calls to this method, further calls result in a no-op. + + Args: + frames (int, optional): number of frames since last step. Defaults to ``1``.
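+ + Examples: + >>> # A minimal annealing sketch; values assume the defaults used above. + >>> import torch + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules import EGreedyModule + >>> module = EGreedyModule(BoundedTensorSpec(-1, 1, torch.Size([4])), eps_init=1.0, eps_end=0.1, annealing_num_steps=1000) + >>> for _ in range(1000): + ... module.step() + >>> print(round(module.eps.item(), 1)) + 0.1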
+ + """ + for _ in range(frames): + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + if isinstance(self.action_key, tuple) and len(self.action_key) > 1: + action_tensordict = tensordict.get(self.action_key[:-1]) + action_key = self.action_key[-1] + else: + action_tensordict = tensordict + action_key = self.action_key + + out = action_tensordict.get(action_key) + eps = self.eps.item() + cond = ( + torch.rand(action_tensordict.shape, device=action_tensordict.device) + < eps + ).to(out.dtype) + cond = expand_as_right(cond, out) + spec = self.spec + if spec is not None: + if isinstance(spec, CompositeSpec): + spec = spec[self.action_key] + if spec.shape != out.shape: + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) + out = cond * spec.rand().to(out.device) + (1 - cond) * out + else: + raise RuntimeError("spec must be provided to the exploration wrapper.") + action_tensordict.set(action_key, out) + return tensordict + + class EGreedyWrapper(TensorDictModuleWrapper): - """Epsilon-Greedy PO wrapper. + """[Deprecated] Epsilon-Greedy PO wrapper. Args: policy (TensorDictModule): a deterministic policy. @@ -34,16 +194,16 @@ class EGreedyWrapper(TensorDictModuleWrapper): eps_end (scalar, optional): final epsilon value. default: 0.1 annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value - action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to - find the action spec. - Default is "action". + action_key (NestedKey, optional): the key where the action can be found in the input tensordict. + Default is ``"action"``. + action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. + Default is ``None`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be - projected onto the valid action space once explored. If not provided, + taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy. .. note:: - Once an environment has been wrapped in :class:`EGreedyWrapper`, it is + Once a module has been wrapped in :class:`EGreedyWrapper`, it is crucial to incorporate a call to :meth:`~.step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception @@ -82,8 +242,15 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): + warnings.warn( + "EGreedyWrapper is deprecated and it will be removed in v0.3. 
" + "Please use torchrl.modules.EGreedyModule instead.", + category=DeprecationWarning, + ) + super().__init__(policy) self.register_buffer("eps_init", torch.tensor([eps_init])) self.register_buffer("eps_end", torch.tensor([eps_end])) @@ -92,6 +259,7 @@ def __init__( self.annealing_num_steps = annealing_num_steps self.register_buffer("eps", torch.tensor([eps_init])) self.action_key = action_key + self.action_mask_key = action_mask_key if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) @@ -105,7 +273,7 @@ def __init__( if action_key not in self._spec.keys(): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = spec @property def spec(self): @@ -149,6 +317,25 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec is not None: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] + if spec.shape != out.shape: + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) out = cond * spec.rand().to(out.device) + (1 - cond) * out else: raise RuntimeError( diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 9e7e4421844..22be1432edf 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,14 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import Optional, Tuple import torch -from tensordict import unravel_key_list +from tensordict import TensorDictBase, unravel_key_list from tensordict.nn import TensorDictModuleBase as ModuleBase -from tensordict.tensordict import NO_DEFAULT, TensorDictBase +from tensordict.tensordict import NO_DEFAULT from tensordict.utils import prod from torch import nn @@ -35,10 +36,10 @@ class LSTMModule(ModuleBase): multi-step. This class enables both usages. - After construction, the module is *not* set in temporal mode, ie. it will + After construction, the module is *not* set in recurrent mode, ie. it will expect single steps inputs. - If in temporal mode, it is expected that the last dimension of the tensordict + If in recurrent mode, it is expected that the last dimension of the tensordict marks the number of steps. There is no constrain on the dimensionality of the tensordict (except that it must be greater than one for temporal inputs). @@ -61,7 +62,6 @@ class LSTMModule(ModuleBase): dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 - proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Keyword Args: in_key (str or tuple of str): the input key of the module. Exclusive use @@ -86,15 +86,15 @@ class LSTMModule(ModuleBase): Exclusive with other nn.LSTM arguments. 
Attributes: - temporal_mode: Returns the temporal mode of the module. + recurrent_mode: Returns the recurrent mode of the module. Methods: - set_temporal_mode: controls whether the module should be executed in - temporal mode. + set_recurrent_mode: controls whether the module should be executed in + recurrent mode. Examples: >>> from torchrl.envs import TransformedEnv, InitTracker - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod @@ -121,6 +121,8 @@ class LSTMModule(ModuleBase): device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False) @@ -205,10 +207,10 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._temporal_mode = False + self._recurrent_mode = False def make_tensordict_primer(self): - from torchrl.envs import TensorDictPrimer + from torchrl.envs.transforms.transforms import TensorDictPrimer def make_tuple(key): if isinstance(key, tuple): @@ -237,29 +239,39 @@ def make_tuple(key): ) @property - def temporal_mode(self): - return self._temporal_mode + def recurrent_mode(self): + return self._recurrent_mode - @temporal_mode.setter - def temporal_mode(self, value): - raise RuntimeError("temporal_mode cannot be changed in-place. Call `module.set") + @recurrent_mode.setter + def recurrent_mode(self, value): + raise RuntimeError( + "recurrent_mode cannot be changed in-place. Call `module.set" + ) + + @property + def temporal_mode(self): + warnings.warn( + "temporal_mode is deprecated, use recurrent_mode instead.", + category=DeprecationWarning, + ) + return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same lstm model but with a different ``temporal_mode`` attribute (if it differs). + """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). 
A copy is created such that the module can be used with divergent behaviour in various parts of the code (inference vs training): Examples: >>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) - >>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) + >>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviours: >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) @@ -275,10 +287,10 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) """ - if mode is self._temporal_mode: + if mode is self._recurrent_mode: return self out = LSTMModule(lstm=self.lstm, in_keys=self.in_keys, out_keys=self.out_keys) - out._temporal_mode = mode + out._recurrent_mode = mode return out def forward(self, tensordict: TensorDictBase): @@ -286,7 +298,7 @@ def forward(self, tensordict: TensorDictBase): defaults = [NO_DEFAULT, None, None] shape = tensordict.shape tensordict_shaped = tensordict - if self.temporal_mode: + if self.recurrent_mode: # if less than 2 dims, unsqueeze ndim = tensordict_shaped.get(self.in_keys[0]).ndim while ndim < 3: @@ -301,11 +313,11 @@ def forward(self, tensordict: TensorDictBase): batch_size=[nelts, tensordict_shaped.shape[-1]], ) else: - tensordict_shaped = tensordict.view(-1).unsqueeze(-1) + tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) is_init = tensordict_shaped.get("is_init").squeeze(-1) splits = None - if self.temporal_mode and is_init[..., 1:].any(): + if self.recurrent_mode and is_init[..., 1:].any(): # if we have consecutive trajectories, things get a little more complicated # we have a tensordict of shape [B, T] # we will split / pad things such that we get a tensordict of shape @@ -340,7 +352,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped.set(self.out_keys[2], hidden1) if splits is not None: # let's recover our original shape - tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).view( + tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape( tensordict_shaped_shape ) @@ -359,7 +371,7 @@ def _lstm( hidden1_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if not self.temporal_mode and steps != 1: + if not self.recurrent_mode and steps != 1: raise ValueError("Expected a single step") if hidden1_in is None and hidden0_in is None: @@ -399,3 +411,389 @@ def _lstm( 1, ) return tuple(out) + + +class GRUModule(ModuleBase): + """An embedder for an GRU module. + + This class adds the following functionality to :class:`torch.nn.GRU`: + + - Compatibility with TensorDict: the hidden states are reshaped to match + the tensordict batch size. 
+    - Optional multi-step execution: with torch.nn, one has to choose between
+      :class:`torch.nn.GRUCell` and :class:`torch.nn.GRU`, the former being
+      compatible with single step inputs and the latter being compatible with
+      multi-step. This class enables both usages.
+
+
+    After construction, the module is *not* set in recurrent mode, i.e. it will
+    expect single-step inputs.
+
+    If in recurrent mode, it is expected that the last dimension of the tensordict
+    marks the number of steps. There is no constraint on the dimensionality of the
+    tensordict (except that it must be greater than one for temporal inputs).
+
+    Args:
+        input_size: The number of expected features in the input `x`
+        hidden_size: The number of features in the hidden state `h`
+        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+            would mean stacking two GRUs together to form a `stacked GRU`,
+            with the second GRU taking in outputs of the first GRU and
+            computing the final results. Default: 1
+        bias: If ``False``, then the layer does not use bias weights.
+            Default: ``True``
+        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+            GRU layer except the last layer, with dropout probability equal to
+            :attr:`dropout`. Default: 0
+
+    Keyword Args:
+        in_key (str or tuple of str): the input key of the module. Exclusive use
+            with ``in_keys``. If provided, the recurrent keys are assumed to be
+            ["recurrent_state"] and the ``in_key`` will be
+            appended before this.
+        in_keys (list of str): a pair of strings corresponding to the input value and recurrent entry.
+            Exclusive with ``in_key``.
+        out_key (str or tuple of str): the output key of the module. Exclusive use
+            with ``out_keys``. If provided, the recurrent keys are assumed to be
+            [("recurrent_state")] and the ``out_key`` will be
+            appended before these.
+        out_keys (list of str): a pair of strings corresponding to the output value
+            and the recurrent (hidden) key.
+            .. note::
+              For a better integration with TorchRL's environments, the best naming
+              for the output hidden key is ``("next", )``, such
+              that the hidden values are passed from step to step during a rollout.
+        device (torch.device or compatible): the device of the module.
+        gru (torch.nn.GRU, optional): a GRU instance to be wrapped.
+            Exclusive with other nn.GRU arguments.
+
+    Attributes:
+        recurrent_mode: Returns the recurrent mode of the module.
+
+    Methods:
+        set_recurrent_mode: controls whether the module should be executed in
+            recurrent mode.
+
+    Examples:
+        >>> from torchrl.envs import TransformedEnv, InitTracker
+        >>> from torchrl.envs import GymEnv
+        >>> from torchrl.modules import MLP
+        >>> from torch import nn
+        >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+        >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+        >>> gru_module = GRUModule(
+        ...     input_size=env.observation_spec["observation"].shape[-1],
+        ...     hidden_size=64,
+        ...     in_keys=["observation", "rs"],
+        ...
out_keys=["intermediate", ("next", "rs")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy(env.reset()) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> gru_module_training = gru_module.set_recurrent_mode() + >>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> traj_td = env.rollout(3) # some random temporal data + >>> traj_td = policy_training(traj_td) + >>> print(traj_td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False), + is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + + """ + + DEFAULT_IN_KEYS = ["recurrent_state"] + DEFAULT_OUT_KEYS = [("next", "recurrent_state")] + + def __init__( + self, + input_size: int = None, + hidden_size: int = None, + num_layers: int = 1, + bias: bool = True, + batch_first=True, + dropout=0, + bidirectional=False, + *, + in_key=None, + in_keys=None, + out_key=None, + out_keys=None, + device=None, + gru=None, + ): + super().__init__() + if gru is not None: + if not gru.batch_first: + raise ValueError("The input gru must have batch_first=True.") + if gru.bidirectional: + raise ValueError("The input gru cannot be bidirectional.") + if input_size is not None or hidden_size is not 
None:
+                raise ValueError(
+                    "A GRU instance cannot be passed along with the class arguments."
+                )
+        else:
+            if not batch_first:
+                raise ValueError("GRUModule requires batch_first=True.")
+            if bidirectional:
+                raise ValueError("GRUModule does not support bidirectional GRUs.")
+            gru = nn.GRU(
+                input_size=input_size,
+                hidden_size=hidden_size,
+                num_layers=num_layers,
+                bias=bias,
+                dropout=dropout,
+                device=device,
+                batch_first=True,
+                bidirectional=False,
+            )
+        if not ((in_key is None) ^ (in_keys is None)):
+            raise ValueError(
+                f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively."
+            )
+        elif in_key:
+            in_keys = [in_key, *self.DEFAULT_IN_KEYS]
+
+        if not ((out_key is None) ^ (out_keys is None)):
+            raise ValueError(
+                f"Either out_keys or out_key must be specified but not both or none. Got {out_keys} and {out_key} respectively."
+            )
+        elif out_key:
+            out_keys = [out_key, *self.DEFAULT_OUT_KEYS]
+
+        in_keys = unravel_key_list(in_keys)
+        out_keys = unravel_key_list(out_keys)
+        if not isinstance(in_keys, (tuple, list)) or (
+            len(in_keys) != 2 and not (len(in_keys) == 3 and in_keys[-1] == "is_init")
+        ):
+            raise ValueError(
+                f"GRUModule expects 2 inputs: a value and a recurrent (hidden) state (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
+            )
+        if not isinstance(out_keys, (tuple, list)) or len(out_keys) != 2:
+            raise ValueError(
+                f"GRUModule expects 2 outputs: a value and a recurrent (hidden) state. Got out_keys {out_keys} instead."
+            )
+        self.gru = gru
+        if "is_init" not in in_keys:
+            in_keys = in_keys + ["is_init"]
+        self.in_keys = in_keys
+        self.out_keys = out_keys
+        self._recurrent_mode = False
+
+    def make_tensordict_primer(self):
+        from torchrl.envs import TensorDictPrimer
+
+        def make_tuple(key):
+            if isinstance(key, tuple):
+                return key
+            return (key,)
+
+        out_key1 = make_tuple(self.out_keys[1])
+        in_key1 = make_tuple(self.in_keys[1])
+        if out_key1 != ("next", *in_key1):
+            raise RuntimeError(
+                "make_tensordict_primer is supposed to work with in_keys/out_keys that "
+                "have compatible names, i.e. the out_keys should be named after ('next', ). Got "
+                f"in_keys={self.in_keys} and out_keys={self.out_keys} instead."
+            )
+        return TensorDictPrimer(
+            {
+                in_key1: UnboundedContinuousTensorSpec(
+                    shape=(self.gru.num_layers, self.gru.hidden_size)
+                ),
+            }
+        )
+
+    @property
+    def recurrent_mode(self):
+        return self._recurrent_mode
+
+    @recurrent_mode.setter
+    def recurrent_mode(self, value):
+        raise RuntimeError(
+            "recurrent_mode cannot be changed in-place. Call `module.set_recurrent_mode(value)` instead."
+        )
+
+    @property
+    def temporal_mode(self):
+        warnings.warn(
+            "temporal_mode is deprecated, use recurrent_mode instead.",
+            category=DeprecationWarning,
+        )
+        return self.recurrent_mode
+
+    def set_recurrent_mode(self, mode: bool = True):
+        """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
+ + A copy is created such that the module can be used with divergent behaviour + in various parts of the code (inference vs training): + + Examples: + >>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp + >>> from torchrl.modules import MLP + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) + >>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) + >>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> # building two policies with different behaviours: + >>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> traj_td = env.rollout(3) # some random temporal data + >>> traj_td = policy_training(traj_td) + >>> # let's check that both return the same results + >>> td_inf = TensorDict({}, traj_td.shape[:-1]) + >>> for td in traj_td.unbind(-1): + ... td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation"))) + ... td_inf = policy_inference(td_inf) + ... td_inf = step_mdp(td_inf) + ... + >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"]) + """ + if mode is self._recurrent_mode: + return self + out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) + out._recurrent_mode = mode + return out + + def forward(self, tensordict: TensorDictBase): + # we want to get an error if the value input is missing, but not the hidden states + defaults = [NO_DEFAULT, None] + shape = tensordict.shape + tensordict_shaped = tensordict + if self.recurrent_mode: + # if less than 2 dims, unsqueeze + ndim = tensordict_shaped.get(self.in_keys[0]).ndim + while ndim < 3: + tensordict_shaped = tensordict_shaped.unsqueeze(0) + ndim += 1 + if ndim > 3: + dims_to_flatten = ndim - 3 + # we assume that the tensordict can be flattened like this + nelts = prod(tensordict_shaped.shape[: dims_to_flatten + 1]) + tensordict_shaped = tensordict_shaped.apply( + lambda value: value.flatten(0, dims_to_flatten), + batch_size=[nelts, tensordict_shaped.shape[-1]], + ) + else: + tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) + + is_init = tensordict_shaped.get("is_init").squeeze(-1) + splits = None + if self.recurrent_mode and is_init[..., 1:].any(): + # if we have consecutive trajectories, things get a little more complicated + # we have a tensordict of shape [B, T] + # we will split / pad things such that we get a tensordict of shape + # [N, T'] where T' <= T and N >= B is the new batch size, such that + # each index of N is an independent trajectory. We'll need to keep + # track of the indices though, as we want to put things back together in the end. 
+ splits = _get_num_per_traj_init(is_init) + tensordict_shaped_shape = tensordict_shaped.shape + tensordict_shaped = _split_and_pad_sequence( + tensordict_shaped.select(*self.in_keys, strict=False), splits + ) + is_init = tensordict_shaped.get("is_init").squeeze(-1) + + value, hidden = ( + tensordict_shaped.get(key, default) + for key, default in zip(self.in_keys, defaults) + ) + batch, steps = value.shape[:2] + device = value.device + dtype = value.dtype + # packed sequences do not help to get the accurate last hidden values + # if splits is not None: + # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) + if is_init.any() and hidden is not None: + hidden[is_init] = 0 + val, hidden = self._gru(value, batch, steps, device, dtype, hidden) + tensordict_shaped.set(self.out_keys[0], val) + tensordict_shaped.set(self.out_keys[1], hidden) + if splits is not None: + # let's recover our original shape + tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape( + tensordict_shaped_shape + ) + + if shape != tensordict_shaped.shape or tensordict_shaped is not tensordict: + tensordict.update(tensordict_shaped.reshape(shape)) + return tensordict + + def _gru( + self, + input: torch.Tensor, + batch, + steps, + device, + dtype, + hidden_in: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if not self.recurrent_mode and steps != 1: + raise ValueError("Expected a single step") + + if hidden_in is None: + shape = (batch, steps) + hidden_in = torch.zeros( + *shape, + self.gru.num_layers, + self.gru.hidden_size, + device=device, + dtype=dtype, + ) + + # we only need the first hidden state + _hidden_in = hidden_in[:, 0] + hidden = _hidden_in.transpose(-3, -2).contiguous() + + y, hidden = self.gru(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = hidden.transpose(0, 1) + + # we pad the hidden states with zero to make tensordict happy + hidden = torch.stack( + [torch.zeros_like(hidden) for _ in range(steps - 1)] + [hidden], + 1, + ) + out = [y, hidden] + return tuple(out) diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index b9b641e23d5..7a7f766a44c 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -25,67 +25,3 @@ def __instancecheck__(self, instance): from .mappings import biased_softplus, inv_softplus, mappings - - -class Buffer(torch.Tensor, metaclass=_ParameterMeta): - r"""A kind of Tensor that is to be considered a module parameter. - - Parameters are :class:`~torch.Tensor` subclasses, that have a - very special property when used with :class:`Module` s - when they're - assigned as Module attributes they are automatically added to the list of - its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator. - Assigning a Tensor doesn't have such effect. This is because one might - want to cache some temporary state, like last hidden state of the RNN, in - the model. If there was no such class as :class:`Parameter`, these - temporaries would get registered too. - - Args: - data (Tensor): parameter tensor. - requires_grad (bool, optional): if the parameter requires gradient. See - :ref:`locally-disable-grad-doc` for more details. Default: `True` - """ - - def __new__(cls, data=None, requires_grad=False): - if data is None: - data = torch.empty(0) - if type(data) is torch.Tensor or type(data) is Buffer: - # For ease of BC maintenance, keep this path for standard Tensor. 
- # Eventually (tm), we should change the behavior for standard Tensor to match. - return torch.Tensor._make_subclass(cls, data, requires_grad) - - # Path for custom tensors: set a flag on the instance to indicate parameter-ness. - t = data.detach().requires_grad_(requires_grad) - if type(t) is not type(data): - raise RuntimeError( - f"Creating a Parameter from an instance of type {type(data).__name__} " - "requires that detach() returns an instance of the same type, but return " - f"type {type(t).__name__} was found instead. To use the type as a " - "Parameter, please correct the detach() semantics defined by " - "its __torch_dispatch__() implementation." - ) - t._is_param = True - return t - - # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types - # are still considered that custom tensor type and these methods will not be called for them. - def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - else: - result = type(self)( - self.data.clone(memory_format=torch.preserve_format), self.requires_grad - ) - memo[id(self)] = result - return result - - def __repr__(self): - return "Buffer containing:\n" + super(Buffer, self).__repr__() - - def __reduce_ex__(self, proto): - # See Note [Don't serialize hooks] - return ( - torch._utils._rebuild_parameter, - (self.data, self.requires_grad, OrderedDict()), - ) - - __torch_function__ = _disabled_torch_function_impl diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 95427fce078..e69de29bb2d 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -1,51 +0,0 @@ -from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - TensorSpec, -) - -ACTION_SPACE_MAP = {} -ACTION_SPACE_MAP[OneHotDiscreteTensorSpec] = "one_hot" -ACTION_SPACE_MAP[MultiOneHotDiscreteTensorSpec] = "mult_one_hot" -ACTION_SPACE_MAP[BinaryDiscreteTensorSpec] = "binary" -ACTION_SPACE_MAP[DiscreteTensorSpec] = "categorical" -ACTION_SPACE_MAP["one_hot"] = "one_hot" -ACTION_SPACE_MAP["one-hot"] = "one_hot" -ACTION_SPACE_MAP["mult_one_hot"] = "mult_one_hot" -ACTION_SPACE_MAP["mult-one-hot"] = "mult_one_hot" -ACTION_SPACE_MAP["multi_one_hot"] = "mult_one_hot" -ACTION_SPACE_MAP["multi-one-hot"] = "mult_one_hot" -ACTION_SPACE_MAP["binary"] = "binary" -ACTION_SPACE_MAP["categorical"] = "categorical" -# TODO for the future ;) -# ACTION_SPACE_MAP[MultiDiscreteTensorSpec] = "multi_categorical" -# ACTION_SPACE_MAP["multi_categorical"] = "multi_categorical" -# ACTION_SPACE_MAP["multi-categorical"] = "multi_categorical" -# ACTION_SPACE_MAP["multi_discrete"] = "multi_categorical" -# ACTION_SPACE_MAP["multi-discrete"] = "multi_categorical" - - -def _find_action_space(action_space): - if isinstance(action_space, TensorSpec): - if isinstance(action_space, CompositeSpec): - if "action" in action_space.keys(): - _key = "action" - else: - # the first key is the action - for _key in action_space.keys(True, True): - if isinstance(_key, tuple) and _key[-1] == "action": - break - else: - raise KeyError - action_space = action_space[_key] - action_space = type(action_space) - try: - action_space = ACTION_SPACE_MAP[action_space] - except KeyError: - raise ValueError( - f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}." 
- ) - return action_space diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 163365bdc75..023b22ba3c4 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -7,6 +7,7 @@ from .common import LossModule from .cql import CQLLoss from .ddpg import DDPGLoss +from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 89314c0603b..bb7b9014f0d 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -97,6 +97,7 @@ class A2CLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -114,7 +115,7 @@ class A2CLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and critic. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic. The return value is a tuple of tensors in the following order: ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None @@ -148,6 +149,7 @@ class A2CLoss(LossModule): ... observation = torch.randn(*batch, n_obs), ... action = spec.rand(batch), ... next_done = torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated = torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() @@ -161,6 +163,7 @@ class A2CLoss(LossModule): ... observation = torch.randn(*batch, n_obs), ... action = spec.rand(batch), ... next_done = torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated = torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() @@ -187,6 +190,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" advantage: NestedKey = "advantage" @@ -195,6 +201,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE @@ -232,12 +239,14 @@ def __init__( self.convert_to_functional(critic, "critic", compare_against=policy_params) self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef - self.register_buffer( - "entropy_coef", torch.tensor(entropy_coef, device=self.device) - ) - self.register_buffer( - "critic_coef", torch.tensor(critic_coef, device=self.device) - ) + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) + self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma @@ -249,6 +258,7 @@ def in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor.in_keys, *[("next", key) for key in self.actor.in_keys], ] @@ -280,6 +290,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) def reset(self) -> None: @@ -387,5 +398,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 3b96913747e..bdccbda3808 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -8,40 +8,32 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple -import torch +from tensordict import TensorDictBase from tensordict.nn import ( make_functional, repopulate_module, TensorDictModule, TensorDictModuleBase, + TensorDictParams, ) - -from tensordict.tensordict import TensorDictBase -from torch import nn, Tensor +from torch import nn from torch.nn import Parameter from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules.utils import Buffer -from torchrl.objectives.utils import _cache_values, ValueEstimators +from torchrl.objectives.utils import ValueEstimators from torchrl.objectives.value import ValueEstimatorBase -_has_functorch = False -try: - import functorch as ft # noqa - - _has_functorch = True - FUNCTORCH_ERR = "" -except ImportError: - print( - "failed to import functorch. TorchRL's features that do not require " - "functional programming should work, but functionality and performance " - "may be affected. Consider installing functorch and/or upgrating pytorch." - ) - FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." 
+ +def _updater_check_forward_prehook(module, *args, **kwargs): + if not all(v for v in module._has_update_associated.values()) and RL_WARNINGS: + warnings.warn( + module.TARGET_NET_WARNING, + category=UserWarning, + ) class LossModule(TensorDictModuleBase): @@ -96,6 +88,13 @@ class _AcceptedKeys: default_value_estimator: ValueEstimators = None SEP = "_sep_" + TARGET_NET_WARNING = ( + "No target network updater has been associated " + "with this loss module, but target parameters have been found. " + "While this is supported, it is expected that the target network " + "updates will be manually performed. You can deactivate this warning " + "by turning the RL_WARNINGS env variable to False." + ) @property def tensor_keys(self) -> _AcceptedKeys: @@ -103,7 +102,6 @@ def tensor_keys(self) -> _AcceptedKeys: def __new__(cls, *args, **kwargs): cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) - cls._tensor_keys = cls._AcceptedKeys() self = super().__new__(cls) return self @@ -112,8 +110,10 @@ def __init__(self): self._cache = {} self._param_maps = {} self._value_estimator = None - self._has_update_associated = False + self._has_update_associated = {} self.value_type = self.default_value_estimator + self._tensor_keys = self._AcceptedKeys() + self.register_forward_pre_hook(_updater_check_forward_prehook) # self.register_forward_pre_hook(_parameters_to_tensordict) def _set_deprecated_ctor_keys(self, **kwargs) -> None: @@ -251,16 +251,9 @@ def convert_to_functional( # we transform the buffers in params to make sure they follow the device # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device - def create_buffers(tensor): - - if isinstance(tensor, torch.Tensor) and not isinstance( - tensor, (Buffer, nn.Parameter) - ): - return Buffer(tensor, requires_grad=tensor.requires_grad) - return tensor - # separate params and buffers - params_and_buffers = params_and_buffers.apply(create_buffers) + params_and_buffers = TensorDictParams(params_and_buffers, no_convert=True) + # sanity check for key in params_and_buffers.keys(True): if sep in key: raise KeyError( @@ -269,11 +262,6 @@ def create_buffers(tensor): params_and_buffers_flat = params_and_buffers.flatten_keys(sep) buffers = params_and_buffers_flat.select(*buffer_names) params = params_and_buffers_flat.exclude(*buffer_names) - if expand_dim and not _has_functorch: - raise ImportError( - "expanding params is only possible when functorch is installed," - "as this feature requires calls to the vmap operator." - ) if compare_against is not None: compare_against = set(compare_against) else: @@ -285,7 +273,6 @@ def create_buffers(tensor): # For buffers, a cloned expansion (or equivalently a repeat) is returned. 
def _compare_and_expand(param): - if param in compare_against: expanded_param = param.data.expand(expand_dim, *param.shape) # the expanded parameter must be sent to device when to() @@ -300,13 +287,12 @@ def _compare_and_expand(param): ) return p_out - params_udpated = params.apply( + params = params.apply( _compare_and_expand, batch_size=[expand_dim, *params.shape] ) - params = params_udpated buffers = buffers.apply( - lambda buffer: Buffer(buffer.expand(expand_dim, *buffer.shape).clone()), + lambda buffer: buffer.expand(expand_dim, *buffer.shape).clone(), batch_size=[expand_dim, *buffers.shape], ) @@ -321,161 +307,45 @@ def _compare_and_expand(param): prev_set_params = set(self.parameters()) # register parameters and buffers - for key, parameter in params.items(): + for key, parameter in list(params_and_buffers.items(True, True)): if parameter not in prev_set_params: - setattr(self, sep.join([module_name, key]), parameter) - else: - # if the parameter is already present, we register a string pointing - # to is instead. If the string ends with a '_detached' suffix, the - # value will be detached - for _param_name, p in self.named_parameters(): - if parameter is p: - break - else: - raise RuntimeError("parameter not found") - if compare_against is not None and p in compare_against: - _param_name = _param_name + "_detached" - setattr(self, sep.join([module_name, key]), _param_name) - prev_set_buffers = set(self.buffers()) - for key, buffer in buffers.items(): - if buffer not in prev_set_buffers: - self.register_buffer(sep.join([module_name, key]), buffer) - else: - for _buffer_name, b in self.named_buffers(): - if buffer is b: - break - else: - raise RuntimeError("buffer not found") - setattr(self, sep.join([module_name, key]), _buffer_name) - - setattr(self, "_" + param_name, params_and_buffers) - setattr( - self.__class__, - param_name, - property(lambda _self=self: _self._param_getter(module_name)), - ) + pass + elif compare_against is not None and parameter in compare_against: + params_and_buffers.set(key, parameter.data) + + setattr(self, param_name, params_and_buffers) # set the functional module setattr(self, module_name, functional_module) - # creates a map nn.Parameter name -> expanded parameter name - for key, value in params.items(True, True): - if not isinstance(key, tuple): - key = (key,) - if not isinstance(value, nn.Parameter): - # find the param name - for name, param in self.named_parameters(): - if param.data.data_ptr() == value.data_ptr() and param is not value: - self._param_maps[name] = sep.join([module_name, *key]) - break - else: - raise RuntimeError(f"key {key} did not find matching param.") - - name_params_target = "_target_" + module_name + name_params_target = "target_" + module_name if create_target_params: - target_params = params_and_buffers.apply(_make_target_param(clone=True)) - target_params_items = target_params.items(True, True) - target_params_list = [] - for (key, val) in target_params_items: - if not isinstance(key, tuple): - key = (key,) - name = sep.join([name_params_target, *key]) - self.register_buffer(name, val) - target_params_list.append((name, key)) - setattr(self, name_params_target + "_params", target_params) - else: - setattr(self, name_params_target + "_params", None) - setattr( - self.__class__, - name_params_target[1:] + "_params", - property(lambda _self=self: _self._target_param_getter(module_name)), - ) - - @_cache_values - def _param_getter(self, network_name): - name = "_" + network_name + "_params" - param_name = network_name + 
"_params" - if name in self.__dict__: - params = getattr(self, name) - if params is not None: - with params.unlock_(): - # get targets and update - for key in params.keys(True, True): - if not isinstance(key, tuple): - key = (key,) - value_to_set = getattr( - self, self.SEP.join([network_name, *key]) - ) - if isinstance(value_to_set, str): - if value_to_set.endswith("_detached"): - value_to_set = value_to_set[:-9] - value_to_set = getattr(self, value_to_set) - is_param = isinstance(value_to_set, nn.Parameter) - is_buffer = isinstance(value_to_set, Buffer) - value_to_set = value_to_set.detach() - if is_param: - value_to_set = nn.Parameter( - value_to_set, requires_grad=False - ) - elif is_buffer: - value_to_set = Buffer( - value_to_set, requires_grad=False - ) - else: - value_to_set = getattr(self, value_to_set) - # params.set(key, value_to_set) - params._set_tuple( - key, value_to_set, inplace=False, validated=True - ) - return params - else: - params = getattr(self, param_name) - return params.apply(_make_target_param(clone=False)) - - else: - raise RuntimeError( - f"{self.__class__.__name__} does not have the target param {name}" - ) - - @_cache_values - def _target_param_getter(self, network_name): - target_name = "_target_" + network_name + "_params" - param_name = network_name + "_params" - if target_name in self.__dict__: - target_params = getattr(self, target_name) - if target_params is not None: - if not self._has_update_associated and RL_WARNINGS: - warnings.warn( - "No target network updater has been associated " - "with this loss module, but target parameters have been found. " - "While this is supported, it is expected that the target network " - "updates will be manually performed. You can deactivate this warning " - "by turning the RL_WARNINGS env variable to False.", - category=UserWarning, - ) - with target_params.unlock_(): - # get targets and update - for key in target_params.keys(True, True): - if not isinstance(key, tuple): - key = (key,) - value_to_set = getattr( - self, self.SEP.join(["_target_" + network_name, *key]) - ) - # target_params.set(key, value_to_set) - target_params._set_tuple( - key, value_to_set, inplace=False, validated=True - ) - else: - params = getattr(self, param_name) - # should we clone here? 
- target_params = params.apply(_make_target_param(clone=False)) - - return target_params - - else: - raise RuntimeError( - f"{self.__class__.__name__} does not have the target param {target_name}" + # if create_target_params: + # we create a TensorDictParams to keep the target params as Buffer instances + target_params = TensorDictParams( + params_and_buffers.apply( + _make_target_param(clone=create_target_params) + ), + no_convert=True, ) + setattr(self, name_params_target + "_params", target_params) + self._has_update_associated[module_name] = not create_target_params + + def __getattr__(self, item): + if item.startswith("target_") and item.endswith("_params"): + params = self._modules.get(item, None) + if params is None: + # no target param, take detached data + params = getattr(self, item[7:]) + params = params.data + elif not self._has_update_associated[item[7:-7]] and RL_WARNINGS: + # no updater associated + warnings.warn( + self.TARGET_NET_WARNING, + category=UserWarning, + ) + return params + return super().__getattr__(item) def _apply(self, fn): # any call to apply erases the cache: the reason is that detached @@ -493,18 +363,6 @@ def _networks(self) -> Iterator[nn.Module]: if isinstance(item, nn.Module): yield item - @property - def device(self) -> torch.device: - for p in self.parameters(): - return p.device - return torch.device("cpu") - - def register_buffer( - self, name: str, tensor: Optional[Tensor], persistent: bool = True - ) -> None: - # tensor = tensor.to(self.device) - return super().register_buffer(name, tensor, persistent) - def parameters(self, recurse: bool = True) -> Iterator[Parameter]: for _, param in self.named_parameters(recurse=recurse): yield param @@ -520,34 +378,6 @@ def reset(self) -> None: # mainly used for PPO with KL target pass - def to(self, *args, **kwargs): - # get the names of the parameters to map - out = super().to(*args, **kwargs) - for origin, target in self._param_maps.items(): - origin_value = getattr(self, origin) - target_value = getattr(self, target) - setattr(self, target, origin_value.expand_as(target_value)) - out._cache = {} - return out - - def cuda(self, device: Optional[Union[int, device]] = None) -> LossModule: - if device is None: - return self.to("cuda") - else: - return self.to(device) - - def double(self) -> LossModule: - return self.to(torch.double) - - def float(self) -> LossModule: - return self.to(torch.float) - - def half(self) -> LossModule: - return self.to(torch.half) - - def cpu(self) -> LossModule: - return self.to(torch.device("cpu")) - @property def value_estimator(self) -> ValueEstimatorBase: """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network.""" @@ -624,6 +454,86 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams else: raise NotImplementedError(f"Unknown value type {value_type}") + # def _apply(self, fn, recurse=True): + # """Modifies torch.nn.Module._apply to work with Buffer class.""" + # if recurse: + # for module in self.children(): + # module._apply(fn) + # + # def compute_should_use_set_data(tensor, tensor_applied): + # if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): + # # If the new tensor has compatible tensor type as the existing tensor, + # # the current behavior is to change the tensor in-place using `.data =`, + # # and the future behavior is to overwrite the existing tensor. 
However, + # # changing the current behavior is a BC-breaking change, and we want it + # # to happen in future releases. So for now we introduce the + # # `torch.__future__.get_overwrite_module_params_on_conversion()` + # # global flag to let the user control whether they want the future + # # behavior of overwriting the existing tensor or not. + # return not torch.__future__.get_overwrite_module_params_on_conversion() + # else: + # return False + # + # for key, param in self._parameters.items(): + # if param is None: + # continue + # # Tensors stored in modules are graph leaves, and we don't want to + # # track autograd history of `param_applied`, so we have to use + # # `with torch.no_grad():` + # with torch.no_grad(): + # param_applied = fn(param) + # should_use_set_data = compute_should_use_set_data(param, param_applied) + # if should_use_set_data: + # param.data = param_applied + # out_param = param + # else: + # assert isinstance(param, Parameter) + # assert param.is_leaf + # out_param = Parameter(param_applied, param.requires_grad) + # self._parameters[key] = out_param + # + # if param.grad is not None: + # with torch.no_grad(): + # grad_applied = fn(param.grad) + # should_use_set_data = compute_should_use_set_data(param.grad, grad_applied) + # if should_use_set_data: + # assert out_param.grad is not None + # out_param.grad.data = grad_applied + # else: + # assert param.grad.is_leaf + # out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad) + # + # for key, buffer in self._buffers.items(): + # if buffer is None: + # continue + # # Tensors stored in modules are graph leaves, and we don't want to + # # track autograd history of `buffer_applied`, so we have to use + # # `with torch.no_grad():` + # with torch.no_grad(): + # buffer_applied = fn(buffer) + # should_use_set_data = compute_should_use_set_data(buffer, buffer_applied) + # if should_use_set_data: + # buffer.data = buffer_applied + # out_buffer = buffer + # else: + # assert isinstance(buffer, Buffer) + # assert buffer.is_leaf + # out_buffer = Buffer(buffer_applied, buffer.requires_grad) + # self._buffers[key] = out_buffer + # + # if buffer.grad is not None: + # with torch.no_grad(): + # grad_applied = fn(buffer.grad) + # should_use_set_data = compute_should_use_set_data(buffer.grad, grad_applied) + # if should_use_set_data: + # assert out_buffer.grad is not None + # out_buffer.grad.data = grad_applied + # else: + # assert buffer.grad.is_leaf + # out_buffer.grad = grad_applied.requires_grad_(buffer.grad.requires_grad) + + return self + class _make_target_param: def __init__(self, clone): @@ -634,4 +544,4 @@ def __call__(self, x): return nn.Parameter( x.data.clone() if self.clone else x.data, requires_grad=False ) - return Buffer(x.data.clone() if self.clone else x.data) + return x.data.clone() if self.clone else x.data diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 06e8de28bf7..249166a6bd2 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -126,6 +126,7 @@ class CQLLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... 
}, batch) @@ -145,7 +146,7 @@ class CQLLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "loss_alpha_prime", "alpha", "entropy"]``. @@ -184,6 +185,7 @@ class CQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -197,6 +199,7 @@ class CQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -229,6 +232,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -366,8 +370,19 @@ def target_entropy(self): ) if not isinstance(action_spec, CompositeSpec): action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) @@ -381,6 +396,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): @@ -420,6 +436,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -437,6 +454,7 @@ def in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 03966fd21a0..1795f785716 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -13,8 +13,8 @@ import torch from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import NestedKey +from tensordict.utils import NestedKey, unravel_key from 
torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -69,6 +69,7 @@ class DDPGLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": spec.rand(batch), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -88,7 +89,7 @@ class DDPGLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["next_reward", "next_done"]`` + in_keys of the actor_network and value_network. + ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]`` @@ -117,6 +118,7 @@ class DDPGLoss(LossModule): ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), + ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... next_reward=torch.randn(1)) >>> loss_actor.backward() @@ -130,6 +132,7 @@ class DDPGLoss(LossModule): ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), + ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... next_reward=torch.randn(1)) >>> loss_actor.backward() @@ -154,6 +157,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" @@ -161,6 +167,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.TD0 @@ -216,6 +223,9 @@ def __init__( self.actor_critic.module[1] = self.value_network self.actor_in_keys = actor_network.in_keys + self.value_exclusive_keys = set(self.value_network.in_keys) - ( + set(self.actor_in_keys) | set(self.actor_network.out_keys) + ) self.loss_function = loss_function @@ -229,18 +239,21 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.state_action_value, reward=self._tensor_keys.reward, done=self._tensor_keys.done, + terminated=self._tensor_keys.terminated, ) self._set_in_keys() def _set_in_keys(self): - keys = [ - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), + in_keys = { + unravel_key(("next", self.tensor_keys.reward)), + unravel_key(("next", self.tensor_keys.done)), + unravel_key(("next", self.tensor_keys.terminated)), *self.actor_in_keys, - *[("next", key) for key in self.actor_in_keys], + *[unravel_key(("next", key)) for key in self.actor_in_keys], *self.value_network.in_keys, - ] - self._in_keys = list(set(keys)) + *[unravel_key(("next", key)) for key in self.value_network.in_keys], + } + self._in_keys = sorted(in_keys, key=str) @property def in_keys(self): @@ -260,41 +273,28 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a priority to items in the tensordict. Args: - tensordict (TensorDictBase): a tensordict with keys ["done", "reward"] and the in_keys of the actor + tensordict (TensorDictBase): a tensordict with keys ["done", "terminated", "reward"] and the in_keys of the actor and value networks. Returns: a tuple of 2 tensors containing the DDPG loss. 
""" - loss_value, td_error, pred_val, target_value = self._loss_value(tensordict) - td_error = td_error.detach() - td_error = td_error.unsqueeze(tensordict.ndimension()) - if tensordict.device is not None: - td_error = td_error.to(tensordict.device) - tensordict.set( - self.tensor_keys.priority, - td_error, - inplace=True, - ) - loss_actor = self._loss_actor(tensordict) + loss_value, metadata = self.loss_value(tensordict) + loss_actor, metadata_actor = self.loss_actor(tensordict) + metadata.update(metadata_actor) return TensorDict( - source={ - "loss_actor": loss_actor.mean(), - "loss_value": loss_value.mean(), - "pred_value": pred_val.mean().detach(), - "target_value": target_value.mean().detach(), - "pred_value_max": pred_val.max().detach(), - "target_value_max": target_value.max().detach(), - }, + source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, batch_size=[], ) - def _loss_actor( + def loss_actor( self, tensordict: TensorDictBase, - ) -> torch.Tensor: - td_copy = tensordict.select(*self.actor_in_keys).detach() + ) -> [torch.Tensor, dict]: + td_copy = tensordict.select( + *self.actor_in_keys, *self.value_exclusive_keys + ).detach() td_copy = self.actor_network( td_copy, params=self.actor_network_params, @@ -303,12 +303,14 @@ def _loss_actor( td_copy, params=self._cached_detached_value_params, ) - return -td_copy.get(self.tensor_keys.state_action_value) + loss_actor = -td_copy.get(self.tensor_keys.state_action_value) + metadata = {} + return loss_actor.mean(), metadata - def _loss_value( + def loss_value( self, tensordict: TensorDictBase, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, dict]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( @@ -326,7 +328,24 @@ def _loss_value( pred_val, target_value, loss_function=self.loss_function ) - return loss_value, (pred_val - target_value).pow(2), pred_val, target_value + td_error = (pred_val - target_value).pow(2) + td_error = td_error.detach() + if tensordict.device is not None: + td_error = td_error.to(tensordict.device) + tensordict.set( + self.tensor_keys.priority, + td_error, + inplace=True, + ) + with torch.no_grad(): + metadata = { + "td_error": td_error.mean(), + "pred_value": pred_val.mean(), + "target_value": target_value.mean(), + "target_value_max": target_value.max(), + "pred_value_max": pred_val.max(), + } + return loss_value.mean(), metadata def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: @@ -355,6 +374,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.state_action_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py new file mode 100644 index 00000000000..24f6c184d7d --- /dev/null +++ b/torchrl/objectives/decision_transformer.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass +from typing import Union + +import torch +from tensordict.nn import dispatch +from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.utils import NestedKey + +from torch import distributions as d +from torchrl.modules import ProbabilisticActor + +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import distance_loss + + +class OnlineDTLoss(LossModule): + r"""TorchRL implementation of the Online Decision Transformer loss. + + Presented in `"Online Decision Transformer" ` + + Args: + actor_network (ProbabilisticActor): stochastic actor + + Keyword Args: + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is None (no minimum value). + max_alpha (float, optional): max value of alpha. + Default is None (no maximum value). + fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its + initial value. Otherwise, alpha will be optimized to + match the 'target_entropy' value. + Default is ``False``. + target_entropy (float or str, optional): Target entropy for the + stochastic policy. Default is "auto", where target entropy is + computed as :obj:`-prod(n_actions)`. + samples_mc_entropy (int): number of samples to estimate the entropy + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + + """ + + action: NestedKey = "action" + + default_keys = _AcceptedKeys() + + def __init__( + self, + actor_network: ProbabilisticActor, + *, + alpha_init: float = 1.0, + min_alpha: float = None, + max_alpha: float = None, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", + samples_mc_entropy: int = 1, + ) -> None: + self._in_keys = None + self._out_keys = None + super().__init__() + + # Actor Network + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + funs_to_decorate=["forward", "get_dist"], + ) + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + if bool(min_alpha) ^ bool(max_alpha): + min_alpha = min_alpha if min_alpha else 0.0 + if max_alpha == 0: + raise ValueError("max_alpha must be either None or greater than 0.") + max_alpha = max_alpha if max_alpha else 1e9 + if min_alpha: + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + else: + self.min_log_alpha = None + if max_alpha: + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + else: + self.max_log_alpha = None + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = actor_network.spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = actor_network.spec.shape + target_entropy = -float( + actor_network.spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() + ) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + + self.samples_mc_entropy = samples_mc_entropy + self._set_in_keys() + + def _set_in_keys(self): + keys = self.actor_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.action) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs): + pass + + @property + def alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = [ + "loss_log_likelihood", + "loss_entropy", + "loss_alpha", + "alpha", + "entropy", + ] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + x = dist.rsample((self.samples_mc_entropy,)) + log_p = dist.log_prob(x) + # log_p: (batch_size, context_len) + return -log_p.mean(axis=0) + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Compute the loss for the Online Decision Transformer.""" + # extract action targets + target_actions = tensordict.get(self.tensor_keys.action).detach() + + action_dist = self.actor_network.get_dist( + tensordict, params=self.actor_network_params + ) + + log_likelihood = action_dist.log_prob(target_actions).mean() + entropy = self.get_entropy_bonus(action_dist).mean() + entropy_bonus = self.alpha.detach() * entropy + + loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() + + out = { + "loss_log_likelihood": -log_likelihood, + "loss_entropy": -entropy_bonus, + "loss_alpha": loss_alpha, + "entropy": entropy.detach(), + "alpha": self.alpha.detach(), + } + return TensorDict(out, []) + + +class DTLoss(LossModule): + r"""TorchRL implementation of the Online Decision Transformer loss. + + Presented in `"Decision Transformer: Reinforcement Learning via Sequence Modeling" ` + + Args: + actor_network (ProbabilisticActor): stochastic actor + + Keyword Args: + loss_function (str): loss function to use. Defaults to ``"l2"``. + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. 
+ """ + + action: NestedKey = "action" + + default_keys = _AcceptedKeys() + + def __init__( + self, + actor_network: ProbabilisticActor, + *, + loss_function: str = "l2", + ) -> None: + self._in_keys = None + self._out_keys = None + super().__init__() + + # Actor Network + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + funs_to_decorate=["forward"], + ) + self.loss_function = loss_function + + def _set_in_keys(self): + keys = self.actor_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.action) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Compute the loss for the Online Decision Transformer.""" + # extract action targets + target_actions = tensordict.get(self.tensor_keys.action).detach() + + pred_actions = self.actor_network( + tensordict, params=self.actor_network_params + ).get(self.tensor_keys.action) + loss = distance_loss( + pred_actions, + target_actions, + loss_function=self.loss_function, + ).mean() + out = { + "loss": loss, + } + return TensorDict(out, []) diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 02b82ff430c..696efbdc650 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -109,6 +109,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" action: NestedKey = "action" @@ -118,6 +121,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() delay_actor: bool = False @@ -248,6 +252,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -264,6 +269,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -434,6 +440,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index d740d45507e..225d5d553bd 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -13,6 +13,8 @@ from torch import nn from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.utils import _find_action_space + from torchrl.envs.utils import step_mdp from torchrl.modules.tensordict_module.actors import ( DistributionalQValueActor, @@ -20,8 +22,6 @@ ) from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible -from torchrl.modules.utils.utils import _find_action_space - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_WARNING, @@ -71,6 +71,7 @@ class DQNLoss(LossModule): ... "action": spec.rand(batch), ... ("next", "observation"): torch.randn(*batch, n_obs), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1) ... }, batch) >>> loss(data) @@ -84,7 +85,7 @@ class DQNLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["observation", "next_observation", "action", "next_reward", "next_done"]``, + ``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``, and a single loss value is returned. Examples: @@ -103,11 +104,13 @@ class DQNLoss(LossModule): >>> action = action_spec.rand() >>> next_reward = torch.randn(1) >>> next_done = torch.zeros(1, dtype=torch.bool) + >>> next_terminated = torch.zeros(1, dtype=torch.bool) >>> loss_val = dqn_loss( ... observation=observation, ... next_observation=next_observation, ... next_reward=next_reward, ... next_done=next_done, + ... next_terminated=next_terminated, ... action=action) """ @@ -137,6 +140,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" advantage: NestedKey = "advantage" @@ -147,6 +153,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -212,6 +219,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -220,6 +228,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.value_network.in_keys, *[("next", key) for key in self.value_network.in_keys], ] @@ -260,6 +269,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -272,29 +282,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: Args: tensordict (TensorDictBase): a tensordict with keys ["action"] and the in_keys of - the value network (observations, "done", "reward" in a "next" tensordict). + the value network (observations, "done", "terminated", "reward" in a "next" tensordict). Returns: a tensor containing the DQN loss. """ - if self.device is not None: - warnings.warn( - "The use of a device for the objective function will soon be deprecated", - category=DeprecationWarning, - ) - device = self.device - else: - device = tensordict.device - tddevice = tensordict.to(device) - - td_copy = tddevice.clone(False) + td_copy = tensordict.clone(False) self.value_network( td_copy, params=self.value_network_params, ) - action = tddevice.get(self.tensor_keys.action) + action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get(self.tensor_keys.action_value) if self.action_space == "categorical": @@ -373,6 +373,8 @@ class _AcceptedKeys: Defaults to ``"reward"``. done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. Defaults to ``"done"``. + terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. + Defaults to ``"terminated"``. steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is exptected. Defaults to ``"steps_to_next_obs"``. 
""" @@ -382,6 +384,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" default_keys = _AcceptedKeys() @@ -433,10 +436,9 @@ def _log_ps_a_categorical(action, action_log_softmax): def forward(self, input_tensordict: TensorDictBase) -> TensorDict: # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py - device = self.device tensordict = TensorDict( source=input_tensordict, batch_size=input_tensordict.batch_size - ).to(device) + ) if tensordict.batch_dims != 1: raise RuntimeError( @@ -453,6 +455,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: action = tensordict.get(self.tensor_keys.action) reward = tensordict.get(("next", self.tensor_keys.reward)) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, 1) discount = self.gamma**steps_to_next_obs @@ -500,12 +503,13 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: # Tz = R^n + (γ^n)z (accounting for terminal states) if isinstance(discount, torch.Tensor): discount = discount.to("cpu") - done = done.to("cpu") + # done = done.to("cpu") + terminated = terminated.to("cpu") reward = reward.to("cpu") support = support.to("cpu") pns_a = pns_a.to("cpu") - Tz = reward + (1 - done.to(reward.dtype)) * discount * support + Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support if Tz.shape != torch.Size([batch_size, atoms]): raise RuntimeError( "Tz shape must be torch.Size([batch_size, atoms]), " @@ -543,7 +547,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: m.view(-1).index_add_(0, index, tensor) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) - loss = -torch.sum(m.to(device) * log_ps_a, 1) + loss = -torch.sum(m.to(input_tensordict.device) * log_ps_a, 1) input_tensordict.set( self.tensor_keys.priority, loss.detach().unsqueeze(1).to(input_tensordict.device), diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index a9e546a72ec..7bdfde573fa 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -216,12 +216,15 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"state_value"``. done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected ("next", done). Defaults to ``"done"``. + terminated (NestedKey): The input tensordict key where the flag if a + trajectory is terminated is expected ("next", terminated). Defaults to ``"terminated"``. 
""" belief: NestedKey = "belief" reward: NestedKey = "reward" value: NestedKey = "state_value" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TDLambda @@ -286,7 +289,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: if self.discount_loss: gamma = self.value_estimator.gamma.to(tensordict.device) - discount = gamma.expand(lambda_target.shape) + discount = gamma.expand(lambda_target.shape).clone() discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) actor_loss = -(lambda_target * discount).sum((-2, -1)).mean() @@ -297,11 +300,13 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor: done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device) + terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device) input_tensordict = TensorDict( { ("next", self.tensor_keys.reward): reward, ("next", self.tensor_keys.value): value, ("next", self.tensor_keys.done): done, + ("next", self.tensor_keys.terminated): terminated, }, [], ) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 6ffff97c66a..966550e21e5 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -103,6 +103,7 @@ class IQLLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -120,7 +121,7 @@ class IQLLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``. @@ -163,6 +164,7 @@ class IQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -177,6 +179,7 @@ class IQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -206,6 +209,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" value: NestedKey = "state_value" @@ -215,6 +221,7 @@ class _AcceptedKeys: state_action_value: NestedKey = "state_action_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -307,6 +314,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -336,6 +344,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -490,5 +499,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index e9eca7ce293..00106571744 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -18,12 +18,12 @@ from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.utils import _find_action_space + from torchrl.modules import SafeSequential from torchrl.modules.tensordict_module.actors import QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible -from torchrl.modules.utils.utils import _find_action_space - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -120,6 +120,7 @@ class QMixerLoss(LossModule): ... "state": torch.zeros(32, 64, 64, 3), ... "reward": torch.zeros(32, 1), ... "done": torch.zeros(32, 1, dtype=torch.bool), + ... "terminated": torch.zeros(32, 1, dtype=torch.bool), ... }, ... [32], ... ), @@ -162,6 +163,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" advantage: NestedKey = "advantage" @@ -173,6 +177,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -260,6 +265,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.global_value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -268,6 +274,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.global_value_network.in_keys, *[("next", key) for key in self.global_value_network.in_keys], ] @@ -312,6 +319,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.global_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 3cd61922dc6..e576ca33c1c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -144,11 +144,11 @@ class PPOLoss(LossModule): >>> loss = PPOLoss(actor, value) >>> batch = [2, ] >>> action = spec.rand(batch) - >>> data = TensorDict({ - ... "observation": torch.randn(*batch, n_obs), + >>> data = TensorDict({"observation": torch.randn(*batch, n_obs), ... "action": action, ... "sample_log_prob": torch.randn_like(action[..., 1]), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -166,7 +166,7 @@ class PPOLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "sample_log_prob", "next_reward", "next_done"]`` + in_keys of the actor and value network. + ``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network. The return value is a tuple of tensors in the following order: ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not None. @@ -204,6 +204,7 @@ class PPOLoss(LossModule): ... action=action, ... sampleLogProb=torch.randn_like(action[..., 1]) / 10, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_objective.backward() @@ -233,6 +234,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" advantage: NestedKey = "advantage" @@ -242,6 +246,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -279,12 +284,14 @@ def __init__( self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus self.separate_losses = separate_losses - self.register_buffer( - "entropy_coef", torch.tensor(entropy_coef, device=self.device) - ) - self.register_buffer( - "critic_coef", torch.tensor(critic_coef, device=self.device) - ) + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) + self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage if gamma is not None: @@ -302,6 +309,7 @@ def _set_in_keys(self): self.tensor_keys.sample_log_prob, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor.in_keys, *[("next", key) for key in self.actor.in_keys], *self.critic.in_keys, @@ -341,6 +349,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -469,6 +478,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -639,11 +649,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() batch = log_weight.shape[0] - if not advantage.shape == log_weight.shape: - raise RuntimeError( - f"advantage.shape and log_weight.shape do not match (got {advantage.shape} " - f"and {log_weight.shape})" - ) gain1 = log_weight.exp() * advantage log_weight_clip = log_weight.clamp(*self._clip_bounds) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 039b5c65b9d..dd64a4bc033 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -8,7 +8,6 @@ from numbers import Number from typing import Union -import numpy as np import torch from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential @@ -124,6 +123,7 @@ class REDQLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -146,7 +146,7 @@ class REDQLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. 
In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``. @@ -187,6 +187,7 @@ class REDQLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_actor.backward() @@ -215,6 +216,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ action: NestedKey = "action" @@ -224,6 +228,7 @@ class _AcceptedKeys: state_action_value: NestedKey = "state_action_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() delay_actor: bool = False @@ -352,8 +357,19 @@ def target_entropy(self): ) if not isinstance(action_spec, CompositeSpec): action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) @@ -367,6 +383,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -383,6 +400,7 @@ def _set_in_keys(self): self.tensor_keys.sample_log_prob, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -602,5 +620,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 7c314bace36..93910f1eebf 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -98,6 +98,7 @@ class ReinforceLoss(LossModule): ... "observation": torch.randn(batch, n_obs), ... "reward": torch.randn(batch, 1), ... "done": torch.zeros(batch, 1, dtype=torch.bool), + ... "terminated": torch.zeros(batch, 1, dtype=torch.bool), ... }, ... "action": torch.randn(batch, n_act), ... 
}, [batch]) @@ -113,7 +114,7 @@ class ReinforceLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and critic network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value"]``. Examples: @@ -141,6 +142,7 @@ class ReinforceLoss(LossModule): ... next_observation=torch.randn(batch, n_obs), ... next_reward=torch.randn(batch, 1), ... next_done=torch.zeros(batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(batch, 1, dtype=torch.bool), ... action=torch.randn(batch, n_act),) >>> loss_actor.backward() @@ -169,6 +171,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ advantage: NestedKey = "advantage" @@ -178,6 +183,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -241,6 +247,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -249,6 +256,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.critic.in_keys, @@ -341,5 +349,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index aeb9adbafea..076df1c54a4 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -5,19 +5,20 @@ import math import warnings from dataclasses import dataclass +from functools import wraps from numbers import Number -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import torch + from tensordict.nn import dispatch, make_functional, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor - -from torchrl.data import CompositeSpec -from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp - +from torchrl.data import CompositeSpec, TensorSpec +from torchrl.data.utils import _find_action_space +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule @@ -43,6 +44,15 @@ FUNCTORCH_ERROR = err +def _delezify(func): + @wraps(func) + def 
new_func(self, *args, **kwargs): + self.target_entropy + return func(self, *args, **kwargs) + + return new_func + + class SACLoss(LossModule): """TorchRL implementation of the SAC loss. @@ -137,6 +147,7 @@ class SACLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -156,7 +167,7 @@ class SACLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + ``"loss_value"`` if version one is used. @@ -199,6 +210,7 @@ class SACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -212,6 +224,7 @@ class SACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -240,6 +253,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ action: NestedKey = "action" @@ -249,6 +265,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -364,7 +381,6 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec - self.target_entropy_buffer = None if self._version == 1: self.actor_critic = ActorCriticWrapper( self.actor_network, self.value_network @@ -377,37 +393,54 @@ def __init__( if self._version == 1: self._vmap_qnetwork00 = vmap(qvalue_network) + @property + def target_entropy_buffer(self): + return self.target_entropy + @property def target_entropy(self): - target_entropy = self.target_entropy_buffer - if target_entropy is None: - delattr(self, "target_entropy_buffer") - target_entropy = self._target_entropy - action_spec = self._action_spec - actor_network = self.actor_network - device = next(self.parameters()).device - if target_entropy == "auto": - action_spec = ( - action_spec - if action_spec is not None - else getattr(actor_network, "spec", None) - ) - if action_spec is None: - raise RuntimeError( - "Cannot infer the dimensionality of the action. 
Consider providing " - "the target entropy explicitely or provide the spec of the " - "action tensor in the actor network." - ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) - target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + target_entropy = self._buffers.get("_target_entropy", None) + if target_entropy is not None: + return target_entropy + target_entropy = self._target_entropy + action_spec = self._action_spec + actor_network = self.actor_network + device = next(self.parameters()).device + if target_entropy == "auto": + action_spec = ( + action_spec + if action_spec is not None + else getattr(actor_network, "spec", None) + ) + if action_spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." ) - self.register_buffer( - "target_entropy_buffer", torch.tensor(target_entropy, device=device) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + + action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape + else: + action_container_shape = action_spec.shape + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) - return self.target_entropy_buffer - return target_entropy + delattr(self, "_target_entropy") + self.register_buffer( + "_target_entropy", torch.tensor(target_entropy, device=device) + ) + return self._target_entropy + + state_dict = _delezify(LossModule.state_dict) + load_state_dict = _delezify(LossModule.load_state_dict) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -415,6 +448,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -460,6 +494,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -476,6 +511,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -516,18 +552,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - device = self.device - td_device = tensordict_reshape.to(device) - if self._version == 1: - loss_qvalue, priority = self._loss_qvalue_v1(td_device) - loss_value = self._loss_value(td_device) + loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape) + loss_value, _ = self._value_loss(tensordict_reshape) else: - loss_qvalue, priority = self._loss_qvalue_v2(td_device) + loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape) loss_value = None - loss_actor = self._loss_actor(td_device) - loss_alpha = self._loss_alpha(td_device) - tensordict_reshape.set(self.tensor_keys.priority, priority) + loss_actor, metadata_actor = 
self._actor_loss(tensordict_reshape) + loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): @@ -536,12 +569,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) if shape: tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"].mean() out = { "loss_actor": loss_actor.mean(), "loss_qvalue": loss_qvalue.mean(), "loss_alpha": loss_alpha.mean(), "alpha": self._alpha, - "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), + "entropy": entropy, } if self._version == 1: out["loss_value"] = loss_value.mean() @@ -552,7 +586,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() - def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: + def _actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: with set_exploration_type(ExplorationType.RANDOM): dist = self.actor_network.get_dist( tensordict, @@ -564,7 +600,8 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: td_q = tensordict.select(*self.qvalue_network.in_keys) td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( - td_q, self._cached_detached_qvalue_params # should we clone? + td_q, + self._cached_detached_qvalue_params, # should we clone? ) min_q_logprob = ( td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -575,9 +612,7 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" ) - # write log_prob in tensordict for alpha loss - tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) - return self._alpha * log_prob - min_q_logprob + return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} @property @_cache_values @@ -593,7 +628,9 @@ def _cached_target_params_actor_value(self): _run_checks=False, ) - def _loss_qvalue_v1(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _qvalue_v1_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: target_params = self._cached_target_params_actor_value with set_exploration_type(ExplorationType.MODE): target_value = self.value_estimator.value_estimate( @@ -608,26 +645,27 @@ def _loss_qvalue_v1(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: f"Batch size={tensordict.shape} is incompatible " f"with num_qvqlue_nets={self.num_qvalue_nets}." 
) - tensordict_chunks = torch.stack( - tensordict.chunk(self.num_qvalue_nets, dim=0), 0 + tensordict_chunks = tensordict.reshape( + self.num_qvalue_nets, -1, *tensordict.shape[1:] + ) + target_chunks = target_value.reshape( + self.num_qvalue_nets, -1, *target_value.shape[1:] ) - target_chunks = torch.stack(target_value.chunk(self.num_qvalue_nets, dim=0), 0) # if vmap=True, it is assumed that the input tensordict must be cast to the param shape tensordict_chunks = self._vmap_qnetwork00( tensordict_chunks, self.qvalue_network_params ) - pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value).squeeze( - -1 - ) + pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value) + pred_val = pred_val.squeeze(-1) loss_value = distance_loss( pred_val, target_chunks, loss_function=self.loss_function ).view(*shape) - priority_value = torch.cat((pred_val - target_chunks).pow(2).unbind(0), 0) + metadata = {"td_error": (pred_val - target_chunks).pow(2).flatten(0, 1)} - return loss_value, priority_value + return loss_value, metadata - def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): + def _compute_target_v2(self, tensordict) -> Tensor: r"""Value network for SAC v2. SAC v2 is based on a value estimate of the form: @@ -645,14 +683,16 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): with set_exploration_type(ExplorationType.RANDOM): next_tensordict = tensordict.get("next").clone(False) next_dist = self.actor_network.get_dist( - next_tensordict, params=actor_params + next_tensordict, params=self.actor_network_params ) next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) # get q-values - next_tensordict_expand = self._vmap_qnetworkN0(next_tensordict, qval_params) + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) state_action_value = next_tensordict_expand.get( self.tensor_keys.state_action_value ) @@ -661,7 +701,7 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): != next_sample_log_prob.shape ): next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - next_state_value = state_action_value - _alpha * next_sample_log_prob + next_state_value = state_action_value - self._alpha * next_sample_log_prob next_state_value = next_state_value.min(0)[0] tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_value @@ -669,14 +709,11 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def _loss_qvalue_v2(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _qvalue_v2_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
- target_value = self._get_value_v2( - tensordict, - self._alpha, - self.actor_network_params, - self.target_qvalue_network_params, - ) + target_value = self._compute_target_v2(tensordict) tensordict_expand = self._vmap_qnetworkN0( tensordict.select(*self.qvalue_network.in_keys), @@ -690,10 +727,13 @@ def _loss_qvalue_v2(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).mean(0) - return loss_qval, td_error.detach().max(0)[0] + ).sum(0) + metadata = {"td_error": td_error.detach().max(0)[0]} + return loss_qval, metadata - def _loss_value(self, tensordict: TensorDictBase) -> Tensor: + def _value_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( @@ -729,16 +769,16 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: loss_value = distance_loss( pred_val, target_val, loss_function=self.loss_function ) - return loss_value + return loss_value, {} + + def _alpha_loss(self, log_prob: Tensor) -> Tensor: - def _loss_alpha(self, tensordict: TensorDictBase) -> Tensor: - log_pi = tensordict.get(self.tensor_keys.log_prob) if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder - alpha_loss = torch.zeros_like(log_pi) + alpha_loss = torch.zeros_like(log_prob) return alpha_loss @property @@ -756,9 +796,15 @@ class DiscreteSACLoss(LossModule): Args: actor_network (ProbabilisticActor): the actor to be trained qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed. - num_actions (int): number of actions in the action space. + action_space (str or TensorSpec): Action space. Must be one of + ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, + or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, + :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, + :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + num_actions (int, optional): number of actions in the action space. + To be provided if target_entropy is set to "auto". num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. - loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", "l1", Default is "smooth_l1". alpha_init (float, optional): initial entropy multiplier. Default is 1.0.
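
Editor's note on the "auto" target-entropy change recurring in the hunks above (OnlineDTLoss, REDQLoss, SACLoss): the old -np.prod(action_spec[...].shape) is replaced by a count of only the event dims of the action leaf, i.e. the dims that come after the shape of the composite container holding the action entry. The snippet below is a minimal, hedged sketch of that rule, not torchrl code; auto_target_entropy, spec and action_key are hypothetical names, and spec is assumed to behave like a CompositeSpec (indexable by nested keys and exposing .shape).

import math

def auto_target_entropy(spec, action_key):
    # Hypothetical helper mirroring the pattern used in the hunks above.
    if isinstance(action_key, tuple) and len(action_key) > 1:
        # Nested action key: the parent container's shape gives the
        # batch-like dims that must not count towards the entropy target.
        container_shape = spec[action_key[:-1]].shape
    else:
        container_shape = spec.shape
    leaf_shape = spec[action_key].shape
    event_dims = leaf_shape[len(container_shape):]
    # -prod(n_actions), counting only the action's own (event) dims
    return -float(math.prod(event_dims))

For a flat "action" key whose spec shape is (n_act,), this reduces to the previous -prod(shape) value; the container-shape handling only changes the result for nested, batched action entries.
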
@@ -788,61 +834,53 @@ class DiscreteSACLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) - >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) + >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, ... in_keys=["logits"], ... out_keys=["action"], ... spec=spec, ... distribution_class=OneHotCategorical) - >>> class ValueClass(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.linear = nn.Linear(n_obs, n_act) - ... def forward(self, obs): - ... return self.linear(obs) - >>> module = ValueClass() - >>> qvalue = ValueOperator( - ... module=module, - ... in_keys=['observation']) - >>> loss = DiscreteSACLoss(actor, qvalue, num_actions=actor.spec["action"].space.n) - >>> batch = [2, ] + >>> qvalue = TensorDictModule( + ... nn.Linear(n_obs, n_act), + ... in_keys=["observation"], + ... out_keys=["action_value"], + ... ) + >>> loss = DiscreteSACLoss(actor, qvalue, action_space=spec, num_actions=spec.space.n) + >>> batch = [2,] >>> action = spec.rand(batch) >>> data = TensorDict({ - ... "observation": torch.randn(*batch, n_obs), - ... "action": action, - ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), - ... ("next", "reward"): torch.randn(*batch, 1), - ... ("next", "observation"): torch.randn(*batch, n_obs), + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), ... 
}, batch) >>> loss(data) TensorDict( - fields={ - action_log_prob_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - next.state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) + fields={ + alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", - "alpha", "entropy", "state_action_value_actor", - "action_log_prob_actor", "next.state_value", "target_value"]`` + "alpha", "entropy"]`` The output keys can also be filtered using :meth:`DiscreteSACLoss.select_out_keys` method. Examples: @@ -883,6 +921,7 @@ class DiscreteSACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -907,13 +946,19 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" action: NestedKey = "action" value: NestedKey = "state_value" + action_value: NestedKey = "action_value" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" + log_prob: NestedKey = "log_prob" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -924,18 +969,15 @@ class _AcceptedKeys: "loss_alpha", "alpha", "entropy", - "state_action_value_actor", - "action_log_prob_actor", - "next.state_value", - "target_value", ] def __init__( self, actor_network: ProbabilisticActor, qvalue_network: TensorDictModule, - num_actions: int, # replace with spec? *, + action_space: Union[str, TensorSpec] = None, + num_actions: Optional[int] = None, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -958,7 +1000,7 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist_params"], + funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -1011,22 +1053,24 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), ) + if action_space is None: + warnings.warn( + "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." + "This behaviour will be deprecated soon and a space will have to be passed." + "Check the DiscreteSACLoss documentation to see how to pass the action space. " + ) + action_space = "one-hot" + self.action_space = _find_action_space(action_space) if target_entropy == "auto": + if num_actions is None: + raise ValueError( + "num_actions needs to be provided if target_entropy == 'auto'" + ) target_entropy = -float(np.log(1.0 / num_actions) * target_entropy_weight) self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) - - self._vmap_getdist = vmap(self.actor_network.get_dist_params) - self._vmap_qnetwork = vmap(self.qvalue_network) - - @property - def alpha(self): - if self.min_log_alpha is not None: - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) - with torch.no_grad(): - alpha = self.log_alpha.exp() - return alpha + self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0)) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1034,6 +1078,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -1042,6 +1087,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -1060,172 +1106,172 @@ def in_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - obs_keys = self.actor_network.in_keys - tensordict_select = tensordict.clone(False).select( - "next", *obs_keys, self.tensor_keys.action - ) + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict - actor_params = torch.stack( - [self.actor_network_params, self.target_actor_network_params], 0 + loss_value, metadata_value = self._value_loss(tensordict_reshape) + loss_actor, metadata_actor = 
self._actor_loss(tensordict_reshape) + loss_alpha = self._alpha_loss( + log_prob=metadata_actor["log_prob"], ) - tensordict_actor_grad = tensordict_select.select( - *obs_keys - ) # to avoid overwriting keys - next_td_actor = step_mdp(tensordict_select).select( - *self.actor_network.in_keys - ) # next_observation -> - tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - tensordict_actor = tensordict_actor.contiguous() - - with set_exploration_type(ExplorationType.RANDOM): - # vmap doesn't support sampling, so we take it out from the vmap - td_params = self._vmap_getdist( - tensordict_actor, - actor_params, + tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"]) + if loss_actor.shape != loss_value.shape: + raise RuntimeError( + f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}" ) - if isinstance(self.actor_network, ProbabilisticActor): - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - else: - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - probs = tensordict_actor_dist.probs - z = (probs == 0.0).float() * 1e-8 - logp_pi = torch.log(probs + z) - logp_pi_pol = torch.sum(probs * logp_pi, dim=-1, keepdim=True) - - # repeat tensordict_actor to match the qvalue size - _actor_loss_td = ( - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) - ) # for actor loss - _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, - ) # for qvalue loss - _next_val_td = ( - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) - ) # for next value estimation - tensordict_qval = torch.cat( - [ - _actor_loss_td, - _next_val_td, - _qval_td, - ], - 0, - ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"].mean() + out = { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_value.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self._alpha, + "entropy": entropy, + } + return TensorDict(out, []) - # cat params - q_params_detach = self.qvalue_network_params.detach() - qvalue_params = torch.cat( - [ - q_params_detach, - self.target_qvalue_network_params, - self.qvalue_network_params, - ], - 0, - ) - tensordict_qval = self._vmap_qnetwork( - tensordict_qval, - qvalue_params, - ) + def _compute_target(self, tensordict) -> Tensor: + r"""Value network for SAC v2. - state_action_value = tensordict_qval.get(self.tensor_keys.value).squeeze(-1) - ( - state_action_value_actor, - next_state_action_value_qvalue, - state_action_value_qvalue, - ) = state_action_value.split( - [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], - dim=0, - ) + SAC v2 is based on a value estimate of the form: - loss_actor = -( - (state_action_value_actor.min(0)[0] * probs[0]).sum(-1, keepdim=True) - - self.alpha * logp_pi_pol[0] - ).mean() + .. 
math:: - pred_next_val = ( - probs[1] - * (next_state_action_value_qvalue.min(0)[0] - self.alpha * logp_pi[1]) - ).sum(dim=-1, keepdim=True) + V = Q(s,a) - \alpha * \log p(a | s) - tensordict_select.set( - ("next", self.value_estimator.tensor_keys.value), pred_next_val - ) - target_value = self.value_estimator.value_estimate(tensordict_select).squeeze( - -1 - ) + This class computes this value given the actor and qvalue network - actions = torch.argmax(tensordict_select.get(self.tensor_keys.action), dim=-1) + """ + tensordict = tensordict.clone(False) + # get actions and log-probs + with torch.no_grad(): + next_tensordict = tensordict.get("next").clone(False) - pred_val_1 = ( - state_action_value_qvalue[0].gather(-1, actions.unsqueeze(-1)).unsqueeze(0) - ) - pred_val_2 = ( - state_action_value_qvalue[1].gather(-1, actions.unsqueeze(-1)).unsqueeze(0) - ) - pred_val = torch.cat([pred_val_1, pred_val_2], dim=0).squeeze() - td_error = (pred_val - target_value.expand_as(pred_val)).pow(2) - loss_qval = ( - distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, + # get probs and log probs for actions computed from "next" + next_dist = self.actor_network.get_dist( + next_tensordict, params=self.actor_network_params ) - .mean(-1) - .sum() - * 0.5 - ) + next_prob = next_dist.probs + next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) - tensordict.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value + ) - loss_alpha = self._loss_alpha(logp_pi_pol) - if not loss_qval.shape == loss_actor.shape: - raise RuntimeError( - f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) + + tensordict.set( + ("next", self.value_estimator.tensor_keys.value), next_state_value ) - td_out = TensorDict( - { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "loss_alpha": loss_alpha.mean(), - "alpha": self.alpha.detach(), - "entropy": -logp_pi.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "action_log_prob_actor": logp_pi.mean().detach(), - "next.state_value": pred_next_val.mean().detach(), - "target_value": target_value.mean().detach(), - }, - [], + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + return target_value + + def _value_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + target_value = self._compute_target(tensordict) + tensordict_expand = self._vmap_qnetworkN0( + tensordict.select(*self.qvalue_network.in_keys), + self.qvalue_network_params, ) - return td_out + action_value = tensordict_expand.get(self.tensor_keys.action_value) + action = tensordict.get(self.tensor_keys.action) + action = action.expand((action_value.shape[0], *action.shape)) # Add vmap dim + + # TODO this block comes from the dqn loss, we need to swap all these with a proper + # helper function which selects the value given the action for all discrete spaces + if 
self.action_space == "categorical": + if action.shape != action_value.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + chosen_action_value = torch.gather(action_value, -1, index=action).squeeze( + -1 + ) + else: + action = action.to(torch.float) + chosen_action_value = (action_value * action).sum(-1) - def _loss_alpha(self, log_pi: Tensor) -> Tensor: - if torch.is_grad_enabled() and not log_pi.requires_grad: + td_error = torch.abs(chosen_action_value - target_value) + loss_qval = distance_loss( + chosen_action_value, + target_value.expand_as(chosen_action_value), + loss_function=self.loss_function, + ).mean(0) + + metadata = { + "td_error": td_error.detach().max(0)[0], + } + return loss_qval, metadata + + def _actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + # get probs and log probs for actions + dist = self.actor_network.get_dist( + tensordict, + params=self.actor_network_params, + ) + prob = dist.probs + log_prob = torch.log(torch.where(prob == 0, 1e-8, prob)) + + td_q = tensordict.select(*self.qvalue_network.in_keys) + td_q = self._vmap_qnetworkN0( + td_q, self._cached_detached_qvalue_params # should we clone? + ) + min_q = td_q.get(self.tensor_keys.action_value).min(0)[0] + + if log_prob.shape != min_q.shape: raise RuntimeError( - "expected log_pi to require gradient for the alpha loss)" + f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" ) + + # like in continuous SAC, we take the entropy term and subtract the minimum of the value ensemble + loss = self._alpha * log_prob - min_q + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + loss = (prob * loss).sum(-1) + + return loss, {"log_prob": (log_prob * prob).sum(-1).detach()} + + def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder - alpha_loss = torch.zeros_like(log_pi) + alpha_loss = torch.zeros_like(log_prob) return alpha_loss + @property + def _alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + @property + @_cache_values + def _cached_detached_qvalue_params(self): + return self.qvalue_network_params.detach() + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator self.value_type = value_type - value_net = None hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) if hasattr(self, "gamma"): @@ -1233,12 +1279,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams if value_type is ValueEstimators.TD1: self._value_estimator = TD1Estimator( **hp, - value_network=value_net, + value_network=None, ) elif value_type is ValueEstimators.TD0: self._value_estimator = TD0Estimator( **hp, - value_network=value_net, + value_network=None, ) elif value_type is ValueEstimators.GAE: raise NotImplementedError( @@ -1247,7 +1293,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type is ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator( **hp, - value_network=value_net, + value_network=None, ) else: raise 
NotImplementedError(f"Unknown value type {value_type}") @@ -1257,5 +1303,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": "value_target", "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 62f0e793f29..9912c143ae6 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -109,6 +109,7 @@ class TD3Loss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -128,7 +129,7 @@ class TD3Loss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. @@ -162,6 +163,7 @@ class TD3Loss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_actor.backward() @@ -187,6 +189,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ action: NestedKey = "action" @@ -194,6 +199,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -270,13 +276,24 @@ def __init__( ) elif action_spec is not None: if isinstance(action_spec, CompositeSpec): - action_spec = action_spec[self.tensor_keys.action] + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape + action_spec = action_spec[self.tensor_keys.action][ + (0,) * len(action_container_shape) + ] if not isinstance(action_spec, BoundedTensorSpec): raise ValueError( f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." 
) - low = action_spec.space.minimum - high = action_spec.space.maximum + low = action_spec.space.low + high = action_spec.space.high else: low, high = bounds if not isinstance(low, torch.Tensor): @@ -302,6 +319,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.state_action_value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -310,6 +328,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -338,129 +357,128 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - obs_keys = self.actor_network.in_keys - tensordict_save = tensordict - tensordict = tensordict.clone(False) - act = tensordict.get(self.tensor_keys.action) - action_shape = act.shape - action_device = act.device - # computing early for reprod - noise = torch.normal( - mean=torch.zeros(action_shape), - std=torch.full(action_shape, self.policy_noise), - ).to(action_device) - noise = noise.clamp(-self.noise_clip, self.noise_clip) - - tensordict_actor_grad = tensordict.select( - *obs_keys - ) # to avoid overwriting keys - next_td_actor = step_mdp(tensordict).select( - *self.actor_network.in_keys - ) # next_observation -> - tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - # DO NOT call contiguous bc we'll update the tds later - actor_output_td = self._vmap_actor_network00( - tensordict_actor, - self._cached_stack_actor_params, - ) - # add noise to target policy - actor_output_td1 = actor_output_td[1] - next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp( - self.min_action, self.max_action + def actor_loss(self, tensordict): + tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys) + tensordict_actor_grad = self.actor_network( + tensordict_actor_grad, self.actor_network_params ) - actor_output_td1.set(self.tensor_keys.action, next_action) - tensordict_actor.set( - self.tensor_keys.action, - actor_output_td.get(self.tensor_keys.action), - ) - - # repeat tensordict_actor to match the qvalue size - _actor_loss_td = ( - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + actor_loss_td = tensordict_actor_grad.select( + *self.qvalue_network.in_keys + ).expand( + self.num_qvalue_nets, *tensordict_actor_grad.batch_size ) # for actor loss - _qval_td = tensordict.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict.select(*self.qvalue_network.in_keys).batch_size, - ) # for qvalue loss - _next_val_td = ( - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) - ) # for next value estimation - tensordict_qval = torch.cat( - [ - _actor_loss_td, - _next_val_td, - _qval_td, - ], - 0, - ) - - # cat params - qvalue_params = torch.cat( - [ + state_action_value_actor = ( + self._vmap_qvalue_network00( + actor_loss_td, self._cached_detach_qvalue_network_params, - self.target_qvalue_network_params, - self.qvalue_network_params, - ], - 0, - ) - tensordict_qval = self._vmap_qvalue_network00( - tensordict_qval, - qvalue_params, + ) + 
.get(self.tensor_keys.state_action_value) + .squeeze(-1) ) + loss_actor = -(state_action_value_actor[0]).mean() + metadata = { + "state_action_value_actor": state_action_value_actor.mean().detach(), + } + return loss_actor, metadata + + def value_loss(self, tensordict): + tensordict = tensordict.clone(False) + + act = tensordict.get(self.tensor_keys.action) - state_action_value = tensordict_qval.get( - self.tensor_keys.state_action_value - ).squeeze(-1) - ( - state_action_value_actor, - next_state_action_value_qvalue, - state_action_value_qvalue, - ) = state_action_value.split( - [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], - dim=0, + # computing early for reprod + noise = (torch.randn_like(act) * self.policy_noise).clamp( + -self.noise_clip, self.noise_clip ) - loss_actor = -(state_action_value_actor.min(0)[0]).mean() + with torch.no_grad(): + next_td_actor = step_mdp(tensordict).select( + *self.actor_network.in_keys + ) # next_observation -> + next_td_actor = self.actor_network( + next_td_actor, self.target_actor_network_params + ) + next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( + self.min_action, self.max_action + ) + next_td_actor.set( + self.tensor_keys.action, + next_action, + ) + next_val_td = next_td_actor.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, *next_td_actor.batch_size + ) # for next value estimation + next_target_q1q2 = ( + self._vmap_qvalue_network00( + next_val_td, + self.target_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + # min over the next target qvalues + next_target_qvalue = next_target_q1q2.min(0)[0] - next_state_value = next_state_action_value_qvalue.min(0)[0] + # set next target qvalues tensordict.set( ("next", self.tensor_keys.state_action_value), - next_state_value.unsqueeze(-1), + next_target_qvalue.unsqueeze(-1), + ) + + qval_td = tensordict.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict.batch_size, ) + # preditcted current qvalues + current_qvalue = ( + self._vmap_qvalue_network00( + qval_td, + self.qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done)) target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - pred_val = state_action_value_qvalue - td_error = (pred_val - target_value).pow(2) + + td_error = (current_qvalue - target_value).pow(2) loss_qval = ( distance_loss( - pred_val, - target_value.expand_as(pred_val), + current_qvalue, + target_value.expand_as(current_qvalue), loss_function=self.loss_function, ) .mean(-1) .sum() - * 0.5 ) + metadata = { + "td_error": td_error, + "next_state_value": next_target_qvalue.mean().detach(), + "pred_value": current_qvalue.mean().detach(), + "target_value": target_value.mean().detach(), + } - tensordict_save.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) + return loss_qval, metadata + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict_save = tensordict + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.value_loss(tensordict_save) + tensordict_save.set( + self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] + ) if not loss_qval.shape == loss_actor.shape: raise RuntimeError( f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" ) td_out = TensorDict( source={ - 
"loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "pred_value": pred_val.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "next_state_value": next_state_value.mean().detach(), - "target_value": target_value.mean().detach(), + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + **metadata_actor, + **metadata_value, }, batch_size=[], ) @@ -493,5 +511,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.state_action_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 662486e900b..bc678ed0154 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -141,33 +141,20 @@ def __init__( self, loss_module: "LossModule", # noqa: F821 ): + from torchrl.objectives.common import LossModule + + if not isinstance(loss_module, LossModule): + raise ValueError("The loss_module must be a LossModule instance.") _has_update_associated = getattr(loss_module, "_has_update_associated", None) - loss_module._has_update_associated = True + for k in loss_module._has_update_associated.keys(): + loss_module._has_update_associated[k] = True try: _target_names = [] - # for properties - for name in loss_module.__class__.__dict__: - if ( - name.startswith("target_") - and (name.endswith("params") or name.endswith("buffers")) - and (getattr(loss_module, name) is not None) - ): + for name, _ in loss_module.named_children(): + # the TensorDictParams is a nn.Module instance + if name.startswith("target_") and name.endswith("_params"): _target_names.append(name) - # for regular lists: raise an exception - for name in loss_module.__dict__: - if ( - name.startswith("target_") - and (name.endswith("params") or name.endswith("buffers")) - and (getattr(loss_module, name) is not None) - ): - raise RuntimeError( - "Your module seems to have a target tensor list contained " - "in a non-dynamic structure (such as a list). If the " - "module is cast onto a device, the reference to these " - "tensors will be lost." - ) - if len(_target_names) == 0: raise RuntimeError( "Did not find any target parameters or buffers in the loss module." @@ -191,7 +178,8 @@ def __init__( self.init_() _has_update_associated = True finally: - loss_module._has_update_associated = _has_update_associated + for k in loss_module._has_update_associated.keys(): + loss_module._has_update_associated[k] = _has_update_associated @property def _targets(self): diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 31c8c291c5b..acd2307a0c3 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -171,12 +171,14 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"advantage"``. value_target (NestedKey): The input tensordict key where the target state value is written to. Will be used for the underlying value estimator Defaults to ``"value_target"``. - value_key (NestedKey): The input tensordict key where the state value is expected. + value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - reward_key (NestedKey): The input tensordict key where the reward is written to. 
+ reward (NestedKey): The input tensordict key where the reward is written to. Defaults to ``"reward"``. - done_key (NestedKey): The key in the input TensorDict that indicates + done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Defaults to ``"terminated"``. steps_to_next_obs_key (NestedKey): The key in the input tensordict that indicates the number of steps to the next observation. Defaults to ``"steps_to_next_obs"``. @@ -187,6 +189,7 @@ class _AcceptedKeys: value: NestedKey = "state_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" default_keys = _AcceptedKeys() @@ -212,6 +215,10 @@ def reward_key(self): def done_key(self): return self.tensor_keys.done + @property + def terminated_key(self): + return self.tensor_keys.terminated + @property def steps_to_next_obs_key(self): return self.tensor_keys.steps_to_next_obs @@ -230,10 +237,14 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are - the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDEstimate. + The data passed to this module should be structured as + :obj:`[*B, T, *F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the + feature dimension(s). The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -302,6 +313,7 @@ def in_keys(self): + [ ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), ] + [("next", in_key) for in_key in self.value_network.in_keys] ) @@ -483,10 +495,14 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are - the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, and ``"next"`` + tensordict state as returned by the environment) necessary to + compute the value estimates and the TDEstimate. + The data passed to this module should be structured as + :obj:`[*B, T, *F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the + feature dimension(s). The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -507,7 +523,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -524,7 +541,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -587,8 +605,13 @@ def value_estimate( next_value = self._next_value(tensordict, target_params, kwargs=kwargs) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) value_target = td0_return_estimate( - gamma=gamma, next_state_value=next_value, reward=reward, done=done + gamma=gamma, + next_state_value=next_value, + reward=reward, + done=done, + terminated=terminated, ) return value_target @@ -674,10 +697,13 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDEstimate. + The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -698,7 +724,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -715,7 +742,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -779,8 +807,14 @@ def value_estimate( next_value = self._next_value(tensordict, target_params, kwargs=kwargs) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) value_target = vec_td1_return_estimate( - gamma, next_value, reward, done, time_dim=tensordict.ndim - 1 + gamma, + next_value, + reward, + done=done, + terminated=terminated, + time_dim=tensordict.ndim - 1, ) return value_target @@ -873,10 +907,13 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDLambdaEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDLambdaEstimate. + The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -898,7 +935,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -916,7 +954,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -980,13 +1019,26 @@ def value_estimate( next_value = self._next_value(tensordict, target_params, kwargs=kwargs) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) if self.vectorized: val = vec_td_lambda_return_estimate( - gamma, lmbda, next_value, reward, done, time_dim=tensordict.ndim - 1 + gamma, + lmbda, + next_value, + reward, + done=done, + terminated=terminated, + time_dim=tensordict.ndim - 1, ) else: val = td_lambda_return_estimate( - gamma, lmbda, next_value, reward, done, time_dim=tensordict.ndim - 1 + gamma, + lmbda, + next_value, + reward, + done=done, + terminated=terminated, + time_dim=tensordict.ndim - 1, ) return val @@ -1096,10 +1148,13 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", "reward", "done" and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the GAE. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the GAE. + The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -1122,7 +1177,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward, "terminated": terminated}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -1141,7 +1197,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -1178,6 +1235,7 @@ def forward( next_value = tensordict.get(("next", self.tensor_keys.value)) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) if self.vectorized: adv, value_target = vec_generalized_advantage_estimate( gamma, @@ -1185,7 +1243,8 @@ def forward( value, next_value, reward, - done, + done=done, + terminated=terminated, time_dim=tensordict.ndim - 1, ) else: @@ -1195,7 +1254,8 @@ def forward( value, next_value, reward, - done, + done=done, + terminated=terminated, time_dim=tensordict.ndim - 1, ) @@ -1254,8 +1314,16 @@ def value_estimate( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) _, value_target = vec_generalized_advantage_estimate( - gamma, lmbda, value, next_value, reward, done, time_dim=tensordict.ndim - 1 + gamma, + lmbda, + value, + next_value, + reward, + done=done, + terminated=terminated, + time_dim=tensordict.ndim - 1, ) return value_target diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index ccd0966bbf6..318ba09d02c 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -2,8 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
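The convention adopted throughout functional.py below is that ``done`` flags the end of a trajectory (termination or truncation) while ``terminated`` flags a true terminal state, so only ``terminated`` may zero out the bootstrap term. A minimal sketch of that rule with made-up numbers (written for this review, not code from the patch):

import torch

gamma = 0.9
reward = torch.tensor([[1.0], [1.0]])
next_value = torch.tensor([[2.0], [2.0]])
# first step is truncated (done but not terminated), second step is terminated
terminated = torch.tensor([[False], [True]])

# bootstrap only where the episode did not truly terminate
td0_target = reward + gamma * (~terminated).float() * next_value
# -> tensor([[2.8], [1.0]]): the truncated step still bootstraps from next_value,
#    the terminated step keeps only the reward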
+from __future__ import annotations import math + +import warnings from functools import wraps from typing import Optional, Tuple, Union @@ -51,7 +54,9 @@ def _transpose_time(fun): ) @wraps(fun) - def transposed_fun(*args, time_dim=-2, **kwargs): + def transposed_fun(*args, **kwargs): + time_dim = kwargs.pop("time_dim", -2) + def transpose_tensor(tensor): if ( not isinstance(tensor, (torch.Tensor, MemmapTensor)) @@ -77,7 +82,7 @@ def transpose_tensor(tensor): if time_dim != -2: args, single_dim = zip(*(transpose_tensor(arg) for arg in args)) single_dim = any(single_dim) - for k, item in kwargs.items(): + for k, item in list(kwargs.items()): item, sd = transpose_tensor(item) single_dim = single_dim or sd kwargs[k] = item @@ -116,6 +121,7 @@ def generalized_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """Generalized advantage estimate of a trajectory. @@ -129,27 +135,37 @@ def generalized_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of trajectory. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. time_dim (int): dimension where the time is unrolled. Defaults to -2. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) dtype = next_state_value.dtype device = state_value.device not_done = (~done).int() + not_terminated = (~terminated).int() *batch_size, time_steps, lastdim = not_done.shape advantage = torch.empty( *batch_size, time_steps, lastdim, device=device, dtype=dtype ) prev_advantage = 0 - gnotdone = gamma * not_done - delta = reward + (gnotdone * next_state_value) - state_value - discount = lmbda * gnotdone + g_not_terminated = gamma * not_terminated + delta = reward + (g_not_terminated * next_state_value) - state_value + discount = lmbda * gamma * not_done for t in reversed(range(time_steps)): prev_advantage = advantage[..., t, :] = delta[..., t, :] + ( prev_advantage * discount[..., t, :] @@ -187,6 +203,7 @@ def _fast_vec_gae( state_value: torch.Tensor, next_state_value: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor, gamma: float, lmbda: float, thr: float = 1e-7, @@ -200,7 +217,8 @@ def _fast_vec_gae( reward (torch.Tensor): a [*B, T, F] tensor containing rewards state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function) next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function) - done (torch.Tensor): a [B, T] boolean tensor containing the done states + done (torch.Tensor): a [B, T] boolean tensor containing the done states. + terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states. gamma (scalar): the gamma decay (trajectory discount) lmbda (scalar): the lambda decay (exponential mean discount) thr (float): threshold for the filter. Below this limit, components will ignored. 
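For cross-checking the vectorized paths below against the reversed-loop recursion above, a self-contained reference implementation is sketched here (written for this review, not part of the patch; it assumes inputs shaped ``[*B, T, 1]`` with time at dim ``-2``):

import torch

def reference_gae(gamma, lmbda, state_value, next_state_value, reward, done, terminated=None):
    # mirrors generalized_advantage_estimate above: `terminated` masks the
    # bootstrap term, `done` stops the lambda accumulation at trajectory ends
    if terminated is None:
        terminated = done
    not_terminated = (~terminated).float()
    not_done = (~done).float()
    delta = reward + gamma * not_terminated * next_state_value - state_value
    discount = gamma * lmbda * not_done
    advantage = torch.zeros_like(state_value)
    running = torch.zeros_like(state_value[..., -1, :])
    for t in reversed(range(state_value.shape[-2])):
        running = delta[..., t, :] + discount[..., t, :] * running
        advantage[..., t, :] = running
    return advantage, advantage + state_value

With ``done == terminated`` this reduces to the previous behaviour; when a step is truncated but not terminated, its delta still bootstraps from the next value while the accumulation stops at the trajectory boundary.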
@@ -213,13 +231,14 @@ def _fast_vec_gae( # _gen_num_per_traj and _split_and_pad_sequence need # time dimension at last position done = done.transpose(-2, -1) + terminated = terminated.transpose(-2, -1) reward = reward.transpose(-2, -1) state_value = state_value.transpose(-2, -1) next_state_value = next_state_value.transpose(-2, -1) gammalmbda = gamma * lmbda - not_done = (~done).int() - td0 = reward + not_done * gamma * next_state_value - state_value + not_terminated = (~terminated).int() + td0 = reward + not_terminated * gamma * next_state_value - state_value num_per_traj = _get_num_per_traj(done) td0_flat, mask = _split_and_pad_sequence(td0, num_per_traj, return_mask=True) @@ -246,6 +265,7 @@ def vec_generalized_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """Vectorized Generalized advantage estimate of a trajectory. @@ -259,23 +279,33 @@ def vec_generalized_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. time_dim (int): dimension where the time is unrolled. Defaults to -2. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) dtype = state_value.dtype - not_done = (~done).to(dtype) - *batch_size, time_steps, lastdim = not_done.shape + *batch_size, time_steps, lastdim = terminated.shape value = gamma * lmbda if isinstance(value, torch.Tensor) and value.numel() > 1: # create tensor while ensuring that gradients are passed + not_done = (~done).to(dtype) gammalmbdas = not_done * value else: # when gamma and lmbda are scalars, use fast_vec_gae implementation @@ -284,6 +314,7 @@ def vec_generalized_advantage_estimate( state_value=state_value, next_state_value=next_state_value, done=done, + terminated=terminated, gamma=gamma, lmbda=lmbda, ) @@ -299,7 +330,8 @@ def vec_generalized_advantage_estimate( first_below_thr = torch.where(first_below_thr)[0][0].item() gammalmbdas = gammalmbdas[..., :first_below_thr, :] - td0 = reward + not_done * gamma * next_state_value - state_value + not_terminated = (~terminated).to(dtype) + td0 = reward + not_terminated * gamma * next_state_value - state_value if len(batch_size) > 1: td0 = td0.flatten(0, len(batch_size) - 1) @@ -336,6 +368,7 @@ def td0_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """TD(0) advantage estimate of a trajectory. @@ -346,15 +379,25 @@ def td0_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. 
+ done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) - returns = td0_return_estimate(gamma, next_state_value, reward, done) + returns = td0_return_estimate(gamma, next_state_value, reward, terminated) advantage = returns - state_value return advantage @@ -363,8 +406,11 @@ def td0_return_estimate( gamma: float, next_state_value: torch.Tensor, reward: torch.Tensor, - done: torch.Tensor, + terminated: torch.Tensor, + *, + done: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + # noqa: D417 """TD(0) discounted return estimate of a trajectory. Also known as bootstrapped Temporal Difference or one-step return. @@ -375,16 +421,24 @@ def td0_return_estimate( must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor - done (Tensor): boolean flag for end of episode. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. + + Keyword Args: + done (Tensor): Deprecated. Use ``terminated`` instead. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == reward.shape == done.shape): + if done is not None: + warnings.warn( + "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." + ) + if not (next_state_value.shape == reward.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) - not_done = (~done).int() - advantage = reward + gamma * not_done * next_state_value + not_terminated = (~terminated).int() + advantage = reward + gamma * not_terminated * next_state_value return advantage @@ -399,6 +453,7 @@ def td1_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -408,7 +463,9 @@ def td1_return_estimate( gamma (scalar): exponential mean discount. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -436,9 +493,12 @@ def td1_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() + not_terminated = (~terminated).int() returns = torch.empty_like(next_state_value) @@ -456,19 +516,29 @@ def td1_return_estimate( "rolling_gamma=False is expected only with time-sensitive gamma values" ) + done_but_not_terminated = (done & ~terminated).int() if rolling_gamma: - gamma = gamma * not_done + gamma = gamma * not_terminated g = next_state_value[..., -1, :] for i in reversed(range(T)): - g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * g + # if not done (and hence not terminated), get the bootstrapped value + # if done but not terminated, get nex_val + # if terminated, take nothing (gamma = 0) + dnt = done_but_not_terminated[..., i, :] + g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * ( + (1 - dnt) * g + dnt * next_state_value[..., i, :] + ) else: for k in range(T): - g = next_state_value[..., -1, :] + g = 0 _gamma = gamma[..., k, :] - nd = not_done + nd = not_terminated _gamma = _gamma.unsqueeze(-2) * nd for i in reversed(range(k, T)): - g = reward[..., i, :] + _gamma[..., i, :] * g + dnt = done_but_not_terminated[..., i, :] + g = reward[..., i, :] + _gamma[..., i, :] * ( + (1 - dnt) * g + dnt * next_state_value[..., i, :] + ) returns[..., k, :] = g return returns @@ -479,6 +549,7 @@ def td1_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -489,7 +560,9 @@ def td1_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -517,12 +590,26 @@ def td1_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) if not state_value.shape == next_state_value.shape: raise RuntimeError("shape of state_value and next_state_value must match") returns = td1_return_estimate( - gamma, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + next_state_value, + reward, + done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) advantage = returns - state_value return advantage @@ -533,7 +620,8 @@ def vec_td1_return_estimate( gamma, next_state_value, reward, - done, + done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: Optional[bool] = None, time_dim: int = -2, ): @@ -543,7 +631,9 @@ def vec_td1_return_estimate( gamma (scalar, Tensor): exponential mean discount. If tensor-valued, next_state_value (Tensor): value function result with new_state input. 
reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -576,6 +666,7 @@ def vec_td1_return_estimate( next_state_value=next_state_value, reward=reward, done=done, + terminated=terminated, rolling_gamma=rolling_gamma, lmbda=1, time_dim=time_dim, @@ -587,7 +678,8 @@ def vec_td1_advantage_estimate( state_value, next_state_value, reward, - done, + done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ): @@ -598,7 +690,9 @@ def vec_td1_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -626,11 +720,25 @@ def vec_td1_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) return ( vec_td1_return_estimate( - gamma, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) - state_value ) @@ -648,6 +756,7 @@ def td_lambda_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -658,7 +767,9 @@ def td_lambda_return_estimate( lmbda (scalar): trajectory discount. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -686,23 +797,26 @@ def td_lambda_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) - not_done = (~done).int() + not_terminated = (~terminated).int() returns = torch.empty_like(next_state_value) + next_state_value = next_state_value * not_terminated *batch, T, lastdim = returns.shape # if gamma is not a tensor of the same shape as other inputs, we use rolling_gamma = True single_gamma = False - if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape): + if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape): single_gamma = True gamma = torch.full_like(next_state_value, gamma) single_lambda = False - if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == not_done.shape): + if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape): single_lambda = True lmbda = torch.full_like(next_state_value, lmbda) @@ -712,26 +826,28 @@ def td_lambda_return_estimate( raise RuntimeError( "rolling_gamma=False is expected only with time-sensitive gamma or lambda values" ) - if rolling_gamma: - gamma = gamma * not_done g = next_state_value[..., -1, :] for i in reversed(range(T)): + dn = done[..., i, :].int() + nv = next_state_value[..., i, :] + lmd = lmbda[..., i, :] + # if done, the bootstrapped gain is the next value, otherwise it's the + # value we computed during the previous iter + g = g * (1 - dn) + nv * dn g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * ( - (1 - lmbda[..., i, :]) * next_state_value[..., i, :] - + lmbda[..., i, :] * g + (1 - lmd) * nv + lmd * g ) else: for k in range(T): g = next_state_value[..., -1, :] _gamma = gamma[..., k, :] _lambda = lmbda[..., k, :] - nd = not_done - _gamma = _gamma.unsqueeze(-2) * nd for i in reversed(range(k, T)): - g = reward[..., i, :] + _gamma[..., i, :] * ( - (1 - _lambda) * next_state_value[..., i, :] + _lambda * g - ) + dn = done[..., i, :].int() + nv = next_state_value[..., i, :] + g = g * (1 - dn) + nv * dn + g = reward[..., i, :] + _gamma * ((1 - _lambda) * nv + _lambda * g) returns[..., k, :] = g return returns @@ -744,6 +860,7 @@ def td_lambda_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -755,7 +872,9 @@ def td_lambda_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -783,12 +902,27 @@ def td_lambda_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) if not state_value.shape == next_state_value.shape: raise RuntimeError("shape of state_value and next_state_value must match") returns = td_lambda_return_estimate( - gamma, lmbda, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + lmbda, + next_state_value, + reward, + done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) advantage = returns - state_value return advantage @@ -800,6 +934,7 @@ def _fast_td_lambda_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor, thr: float = 1e-7, ): """Fast vectorized TD lambda return estimate. @@ -812,7 +947,8 @@ def _fast_td_lambda_return_estimate( lmbda (scalar): the lambda decay (exponential mean discount) next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function) reward (torch.Tensor): a [*B, T, F] tensor containing rewards - done (torch.Tensor): a [B, T] boolean tensor containing the done states + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for end of episode. thr (float): threshold for the filter. Below this limit, components will ignored. Defaults to 1e-7. @@ -822,23 +958,25 @@ def _fast_td_lambda_return_estimate( """ device = reward.device done = done.transpose(-2, -1) + terminated = terminated.transpose(-2, -1) reward = reward.transpose(-2, -1) next_state_value = next_state_value.transpose(-2, -1) + # the only valid next states are those where the trajectory does not terminate + next_state_value = (~terminated).int() * next_state_value + gamma_tensor = torch.tensor([gamma], device=device) gammalmbda = gamma_tensor * lmbda - not_done = (~done).int() num_per_traj = _get_num_per_traj(done) - nvalue_ndone = not_done * next_state_value - t = nvalue_ndone * gamma_tensor * (1 - lmbda) + reward - v3 = torch.zeros_like(t, device=device) - v3[..., -1] = nvalue_ndone[..., -1].clone() + done = done.clone() + done[..., -1] = 1 + not_done = (~done).int() - t_flat, mask = _split_and_pad_sequence( - t + v3 * gammalmbda, num_per_traj, return_mask=True - ) + t = reward + next_state_value * gamma_tensor * (1 - not_done * lmbda) + + t_flat, mask = _split_and_pad_sequence(t, num_per_traj, return_mask=True) gammalmbdas = _geom_series_like(t_flat[0], gammalmbda, thr=thr) @@ -855,6 +993,7 @@ def vec_td_lambda_return_estimate( next_state_value, reward, done, + terminated: torch.Tensor | None = None, rolling_gamma: Optional[bool] = None, time_dim: int = -2, ): @@ -868,7 +1007,9 @@ def vec_td_lambda_return_estimate( must be a [Batch x TimeSteps x 1] tensor reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -896,7 +1037,9 @@ def vec_td_lambda_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) gamma_thr = 1e-7 @@ -916,6 +1059,7 @@ def _is_scalar(tensor): next_state_value=next_state_value, reward=reward, done=done, + terminated=terminated, thr=gamma_thr, ) @@ -930,16 +1074,18 @@ def _is_scalar(tensor): """Vectorized version of td_lambda_advantage_estimate""" device = reward.device not_done = (~done).int() + not_terminated = (~terminated).int().transpose(-2, -1).unsqueeze(-2) + if len(batch): + not_terminated = not_terminated.flatten(0, len(batch)) + next_state_value = next_state_value * not_terminated if rolling_gamma is None: rolling_gamma = True - if rolling_gamma: - gamma = gamma * not_done - gammas = _make_gammas_tensor(gamma, T, rolling_gamma) - if not rolling_gamma: - done_follows_done = done[..., 1:, :][done[..., :-1, :]].all() - if not done_follows_done: + terminated_follows_terminated = terminated[..., 1:, :][ + terminated[..., :-1, :] + ].all() + if not terminated_follows_terminated: raise NotImplementedError( "When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, " "make sure that conseducitve trajectories are separated as different batch " @@ -948,46 +1094,47 @@ def _is_scalar(tensor): "consider using the non-vectorized version of the return computation or splitting " "your trajectories." ) - else: - gammas[..., 1:, :] = gammas[..., 1:, :] * not_done.view(-1, 1, T, 1) - gammas_cp = torch.cumprod(gammas, -2) - - lambdas = torch.ones(T + 1, 1, device=device) - lambdas[1:] = lmbda - lambdas_cp = torch.cumprod(lambdas, -2) - - gammas = gammas[..., 1:, :] - lambdas = lambdas[1:] - - dec = gammas_cp * lambdas_cp - if rolling_gamma in (None, True): + if rolling_gamma: + # Make the coefficient table + gammas = _make_gammas_tensor(gamma * not_done, T, rolling_gamma) + gammas_cp = torch.cumprod(gammas, -2) + lambdas = torch.ones(T + 1, 1, device=device) + lambdas[1:] = lmbda + lambdas_cp = torch.cumprod(lambdas, -2) + lambdas = lambdas[1:] + dec = gammas_cp * lambdas_cp + + gammas = _make_gammas_tensor(gamma, T, rolling_gamma) + gammas = gammas[..., 1:, :] if gammas.ndimension() == 4 and gammas.shape[1] > 1: gammas = gammas[:, :1] if lambdas.ndimension() == 4 and lambdas.shape[1] > 1: lambdas = lambdas[:, :1] - v3 = (gammas * lambdas).squeeze(-1) * next_state_value + + not_done = not_done.transpose(-2, -1).unsqueeze(-2) + if len(batch): + not_done = not_done.flatten(0, len(batch)) + # lambdas = lambdas * not_done + + v3 = (gammas * lambdas).squeeze(-1) * next_state_value * not_done v3[..., :-1] = 0 out = _custom_conv1d( - reward + (gammas * (1 - lambdas)).squeeze(-1) * next_state_value + v3, dec + reward + + gammas.squeeze(-1) + * next_state_value + * (1 - lambdas.squeeze(-1) * not_done) + + v3, + dec, ) + return out.view(*batch, lastdim, T).transpose(-2, -1) else: - v1 = _custom_conv1d(reward, dec) - - if gammas.ndimension() == 4 and gammas.shape[1] > 1: - gammas = gammas[:, :, :1].transpose(1, 2) - if lambdas.ndimension() == 4 and lambdas.shape[1] > 1: - lambdas = lambdas[:, :, :1].transpose(1, 2) - - v2 = _custom_conv1d( - next_state_value * not_done.view_as(next_state_value), - dec * (gammas * (1 - lambdas)).transpose(1, 2), + raise NotImplementedError( + "The vectorized version of TD(lambda) with rolling_gamma=False is currently not available. " + "To use this feature, use the non-vectorized version of TD(lambda). 
You can expect " + "good speed improvements by decorating the function with torch.compile!" ) - v3 = next_state_value * not_done.view_as(next_state_value) - v3[..., :-1] = 0 - v3 = _custom_conv1d(v3, dec * (gammas * lambdas).transpose(1, 2)) - return (v1 + v2 + v3).view(*batch, lastdim, T).transpose(-2, -1) def vec_td_lambda_advantage_estimate( @@ -997,6 +1144,7 @@ def vec_td_lambda_advantage_estimate( next_state_value, reward, done, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ): @@ -1008,7 +1156,9 @@ def vec_td_lambda_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -1036,7 +1186,15 @@ def vec_td_lambda_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) return ( vec_td_lambda_return_estimate( @@ -1044,8 +1202,9 @@ def vec_td_lambda_advantage_estimate( lmbda, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) - state_value @@ -1069,7 +1228,8 @@ def reward2go( Args: reward (torch.Tensor): A tensor containing the rewards received at each time step over multiple trajectories. - done (torch.Tensor): A tensor with done (or truncated) states. + done (Tensor): boolean flag for end of episode. Differs from + truncated, where the episode did not end but was interrupted. gamma (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards. Defaults to 1.0. time_dim (int): dimension where the time is unrolled. Defaults to -2. diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index b5e9ce73319..e8e610af122 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -191,20 +191,20 @@ def _flatten_batch(tensor): return tensor.flatten(0, -1) -def _get_num_per_traj(dones_and_truncated): +def _get_num_per_traj(done): """Because we mark the end of each batch with a truncated signal, we can concatenate them. 
Args: - dones_and_truncated (torch.Tensor): A done or truncated mark of shape [*B, T] + done (torch.Tensor): A done or truncated mark of shape [*B, T] Returns: A list of integers representing the number of steps in each trajectory """ - dones_and_truncated = dones_and_truncated.clone() - dones_and_truncated[..., -1] = True + done = done.clone() + done[..., -1] = True # TODO: find a way of copying once only, eg not using reshape - num_per_traj = torch.where(dones_and_truncated.reshape(-1))[0] + 1 + num_per_traj = torch.where(done.reshape(-1))[0] + 1 num_per_traj[1:] = num_per_traj[1:] - num_per_traj[:-1] return num_per_traj diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index be720e7687c..726d29ea051 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -3,4 +3,5 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger from .recorder import TensorDictRecorder, VideoRecorder diff --git a/torchrl/record/loggers/__init__.py b/torchrl/record/loggers/__init__.py index a14b39138e9..92714675046 100644 --- a/torchrl/record/loggers/__init__.py +++ b/torchrl/record/loggers/__init__.py @@ -5,9 +5,9 @@ from .common import Logger -# from .csv import CSVLogger -# from .mlflow import MLFlowLogger -# from .tensorboard import TensorboardLogger +from .csv import CSVLogger +from .mlflow import MLFlowLogger +from .tensorboard import TensorboardLogger from .utils import generate_exp_name, get_logger -# from .wandb import WandbLogger +from .wandb import WandbLogger diff --git a/torchrl/record/loggers/common.py b/torchrl/record/loggers/common.py index 0aa201b967c..b8325763166 100644 --- a/torchrl/record/loggers/common.py +++ b/torchrl/record/loggers/common.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import abc -from typing import Sequence +from typing import Dict, Sequence, Union from torch import Tensor @@ -33,7 +33,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non ... @abc.abstractmethod - def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 + def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 ... 
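A tiny worked example of ``_get_num_per_traj`` as rewritten above (it is a private helper, shown here only for illustration): the last step of the batch is always treated as a trajectory boundary, so a single row of five steps with a done flag at index 2 splits into trajectories of length 3 and 2.

import torch

from torchrl.objectives.value.utils import _get_num_per_traj

done = torch.tensor([[False, False, True, False, False]])
print(_get_num_per_traj(done))  # expected: tensor([3, 2])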
@abc.abstractmethod diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index ea94444ddc2..4f7ae47606a 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -5,7 +5,7 @@ import os from collections import defaultdict from pathlib import Path -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence, Union import torch from torch import Tensor @@ -21,10 +21,12 @@ def __init__(self, log_dir: str): self.videos_counter = defaultdict(lambda: 0) self.text_counter = defaultdict(lambda: 0) self.log_dir = log_dir - os.makedirs(self.log_dir) - os.makedirs(os.path.join(self.log_dir, "scalars")) - os.makedirs(os.path.join(self.log_dir, "videos")) - os.makedirs(os.path.join(self.log_dir, "texts")) + os.makedirs(self.log_dir, exist_ok=True) + os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True) + os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True) + os.makedirs(os.path.join(self.log_dir, "texts"), exist_ok=True) + + self.files = {} def add_scalar(self, name: str, value: float, global_step: Optional[int] = None): if global_step is None: @@ -32,8 +34,11 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None) value = float(value) self.scalars[name].append((global_step, value)) filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"])) - with open(filepath, "a") as fd: - fd.write(",".join([str(global_step), str(value)]) + "\n") + if filepath not in self.files: + self.files[filepath] = open(filepath, "a") + fd = self.files[filepath] + fd.write(",".join([str(global_step), str(value)]) + "\n") + fd.flush() def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs): if global_step is None: @@ -53,12 +58,19 @@ def add_text(self, tag, text, global_step: Optional[int] = None): filepath = os.path.join( self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt" ) - with open(filepath, "w+") as f: - f.writelines(text) + if filepath not in self.files: + self.files[filepath] = open(filepath, "w+") + fd = self.files[filepath] + fd.writelines(text) + fd.flush() def __repr__(self) -> str: return f"CSVExperiment(log_dir={self.log_dir})" + def __del__(self): + for val in getattr(self, "files", {}).values(): + val.close() + class CSVLogger(Logger): """A minimal-dependecy CSV-logger. @@ -112,13 +124,13 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non **kwargs, ) - def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 + def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: - cfg (DictConfig): The configuration of the experiment. + cfg (DictConfig or dict): The configuration of the experiment. """ - txt = "\n\t".join([f"{k}: {val}" for k, val in sorted(vars(cfg).items())]) + txt = "\n".join([f"{k}: {val}" for k, val in sorted(cfg.items())]) self.experiment.add_text("hparams", txt) def __repr__(self) -> str: diff --git a/torchrl/record/loggers/mlflow.py b/torchrl/record/loggers/mlflow.py index 5253c466799..34f8f4f8d3a 100644 --- a/torchrl/record/loggers/mlflow.py +++ b/torchrl/record/loggers/mlflow.py @@ -2,36 +2,20 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
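A hedged usage sketch of the CSV logger after the change above, which keeps file handles open and flushes instead of reopening the file on every write. The constructor arguments are assumed to be ``exp_name`` and ``log_dir``; note that ``log_hparams`` now also accepts a plain ``dict``:

from torchrl.record.loggers import CSVLogger

logger = CSVLogger(exp_name="demo", log_dir="./csv_logs")
for step in range(3):
    # appends "<step>,<value>" to scalars/reward.csv and flushes the handle
    logger.log_scalar("reward", float(step), step=step)
# a plain dict is accepted in addition to an omegaconf DictConfig
logger.log_hparams({"lr": 3e-4, "gamma": 0.99})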
+import importlib.util import os from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, Optional, Sequence, Union -try: - import torchvision - - _has_tv = True -except ImportError: - _has_tv = False from torch import Tensor -from .common import Logger - -MLFLOW_ERR = None -try: - import mlflow +from torchrl.record.loggers.common import Logger - _has_mlflow = True -except ImportError as err: - _has_mlflow = False - MLFLOW_ERR = err +_has_tv = importlib.util.find_spec("torchvision") is not None -try: - from omegaconf import OmegaConf - - _has_omgaconf = True -except ImportError: - _has_omgaconf = False +_has_mlflow = importlib.util.find_spec("mlflow") is not None +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None class MLFlowLogger(Logger): @@ -49,6 +33,8 @@ def __init__( tags: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: + import mlflow + self._mlflow_kwargs = { "name": exp_name, "artifact_location": tracking_uri, @@ -58,14 +44,16 @@ def __init__( super().__init__(exp_name=exp_name, log_dir=tracking_uri) self.video_log_counter = 0 - def _create_experiment(self) -> "mlflow.ActiveRun": + def _create_experiment(self) -> "mlflow.ActiveRun": # noqa + import mlflow + """Creates an mlflow experiment. Returns: mlflow.ActiveRun: The mlflow experiment object. """ if not _has_mlflow: - raise ImportError("MLFlow is not installed") from MLFLOW_ERR + raise ImportError("MLFlow is not installed") self.id = mlflow.create_experiment(**self._mlflow_kwargs) return mlflow.start_run(experiment_id=self.id) @@ -78,6 +66,8 @@ def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> Non step (int, optional): The step at which the scalar is logged. Defaults to None. """ + import mlflow + mlflow.set_experiment(experiment_id=self.id) mlflow.log_metric(key=name, value=value, step=step) @@ -91,6 +81,9 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: **kwargs: Other keyword arguments. By construction, log_video supports 'step' (integer indicating the step index) and 'fps' (default: 6). """ + import mlflow + import torchvision + if not _has_tv: raise ImportError( "Loggin a video with MLFlow requires torchvision to be installed." @@ -112,14 +105,17 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: torchvision.io.write_video(filename=f.name, video_array=video, fps=fps) mlflow.log_artifact(f.name, "videos") - def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 + def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: - cfg (DictConfig): The configuration of the experiment. + cfg (DictConfig or dict): The configuration of the experiment. """ + import mlflow + from omegaconf import OmegaConf + mlflow.set_experiment(experiment_id=self.id) - if type(cfg) is not dict and _has_omgaconf: + if type(cfg) is not dict and _has_omegaconf: cfg = OmegaConf.to_container(cfg, resolve=True) mlflow.log_params(cfg) diff --git a/torchrl/record/loggers/tensorboard.py b/torchrl/record/loggers/tensorboard.py index 557193cb038..12e52a91a64 100644 --- a/torchrl/record/loggers/tensorboard.py +++ b/torchrl/record/loggers/tensorboard.py @@ -2,20 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
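The import strategy used above, shown in isolation: optional dependencies are detected with ``importlib.util.find_spec`` (no import at module load time) and only imported inside the methods that need them. ``somepackage`` and its ``process`` call are placeholders, not real dependencies:

import importlib.util

_has_somepackage = importlib.util.find_spec("somepackage") is not None


def log_something(data):
    """Sketch of a method that lazily imports its optional dependency."""
    if not _has_somepackage:
        raise ImportError("somepackage is required for log_something")
    import somepackage  # deferred import keeps module import cheap

    return somepackage.process(data)  # hypothetical call on the placeholder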
+import importlib.util + import os -from typing import Sequence +from typing import Dict, Sequence, Union from torch import Tensor from .common import Logger - -try: - from torch.utils.tensorboard import SummaryWriter - - _has_tb = True -except ImportError: - _has_tb = False +_has_tb = importlib.util.find_spec("tensorboard") is not None +_has_omgaconf = importlib.util.find_spec("omegaconf") is not None class TensorboardLogger(Logger): @@ -34,7 +31,7 @@ def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None: self._has_imported_moviepy = False - def _create_experiment(self) -> "SummaryWriter": + def _create_experiment(self) -> "SummaryWriter": # noqa """Creates a tensorboard experiment. Args: @@ -47,6 +44,8 @@ def _create_experiment(self) -> "SummaryWriter": if not _has_tb: raise ImportError("torch.utils.tensorboard could not be imported") + from torch.utils.tensorboard import SummaryWriter + log_dir = str(os.path.join(self.log_dir, self.exp_name)) return SummaryWriter(log_dir=log_dir) @@ -92,15 +91,23 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non **kwargs, ) - def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 + def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: - cfg (DictConfig): The configuration of the experiment. + cfg (DictConfig or dict): The configuration of the experiment. """ - txt = "\n\t".join([f"{k}: {val}" for k, val in sorted(vars(cfg).items())]) - self.experiment.add_text("hparams", txt) + if type(cfg) is not dict and _has_omgaconf: + if not _has_omgaconf: + raise ImportError( + "OmegaConf could not be imported. " + "Cannot log hydra configs without OmegaConf." + ) + from omegaconf import OmegaConf + + cfg = OmegaConf.to_container(cfg, resolve=True) + self.experiment.add_hparams(cfg, metric_dict={}) def __repr__(self) -> str: return f"TensorboardLogger(experiment={self.experiment.__repr__()})" diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index fd5ab756f92..9a818753956 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -2,30 +2,18 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib.util import os import warnings -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence, Union from torch import Tensor from .common import Logger - -try: - import wandb - - _has_wandb = True -except ImportError: - _has_wandb = False - - -try: - from omegaconf import OmegaConf - - _has_omgaconf = True -except ImportError: - _has_omgaconf = False +_has_wandb = importlib.util.find_spec("wandb") is not None +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None class WandbLogger(Logger): @@ -92,11 +80,13 @@ def _create_experiment(self) -> "WandbLogger": Returns: WandbLogger: The wandb experiment logger. """ + if not _has_wandb: + raise ImportError("Wandb is not installed") + import wandb + if self.offline: os.environ["WANDB_MODE"] = "dryrun" - if not _has_wandb: - raise ImportError("Wandb is not installed") return wandb.init(**self._wandb_kwargs) def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None: @@ -124,6 +114,8 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: (default is 'mp4') and 'fps' (default: 6). Other kwargs are passed as-is to the :obj:`experiment.log` method. 
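The ``log_hparams`` signatures above now accept either a hydra/omegaconf ``DictConfig`` or a plain ``dict``. A standalone sketch of the normalization they perform (not part of the diff; the helper name is made up):

import importlib.util
from typing import Any, Dict

_has_omegaconf = importlib.util.find_spec("omegaconf") is not None


def to_plain_dict(cfg: Any) -> Dict[str, Any]:
    """Return ``cfg`` as a plain dict, converting DictConfig when possible."""
    if isinstance(cfg, dict):
        return cfg
    if not _has_omegaconf:
        raise ImportError(
            "OmegaConf could not be imported. "
            "Cannot log hydra configs without OmegaConf."
        )
    from omegaconf import OmegaConf

    return OmegaConf.to_container(cfg, resolve=True)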
""" + import wandb + # check for correct format of the video tensor ((N), T, C, H, W) # check that the color channel (C) is either 1 or 3 if video.dim() != 5 or video.size(dim=2) not in {1, 3}: @@ -159,19 +151,21 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: **kwargs, ) - def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 + def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: - cfg (DictConfig): The configuration of the experiment. + cfg (DictConfig or dict): The configuration of the experiment. """ - if type(cfg) is not dict and _has_omgaconf: - if not _has_omgaconf: + if type(cfg) is not dict and _has_omegaconf: + if not _has_omegaconf: raise ImportError( "OmegaConf could not be imported. " "Cannot log hydra configs without OmegaConf." ) + from omegaconf import OmegaConf + cfg = OmegaConf.to_container(cfg, resolve=True) self.experiment.config.update(cfg, allow_val_change=True) @@ -190,6 +184,8 @@ def log_histogram(self, name: str, data: Sequence, **kwargs): bins (str): One of {‘tensorflow’,’auto’, ‘fd’, …}. This determines how the bins are made. You can find other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html """ + import wandb + num_bins = kwargs.pop("bins", None) step = kwargs.pop("step", None) extra_kwargs = {} diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 56c68e065ca..ba8ec2604fe 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from copy import copy from typing import Optional, Sequence import torch +from tensordict.tensordict import TensorDictBase + +from tensordict.utils import NestedKey + +from torchrl.envs.transforms import ObservationTransform, Transform +from torchrl.record.loggers import Logger + try: from torchvision.transforms.functional import center_crop as center_crop_fn from torchvision.utils import make_grid except ImportError: center_crop_fn = None -from tensordict.tensordict import TensorDictBase - -from torchrl.envs.transforms import ObservationTransform, Transform -from torchrl.record.loggers import Logger - class VideoRecorder(ObservationTransform): """Video Recorder transform. @@ -29,7 +32,7 @@ class VideoRecorder(ObservationTransform): logger (Logger): a Logger instance where the video should be written. tag (str): the video tag in the logger. - in_keys (Sequence[str], optional): keys to be read to produce the video. + in_keys (Sequence of NestedKey, optional): keys to be read to produce the video. Default is :obj:`"pixels"`. skip (int): frame interval in the output video. Default is 2. @@ -37,6 +40,8 @@ class VideoRecorder(ObservationTransform): make_grid (bool, optional): if ``True``, a grid is created assuming that a tensor of shape [B x W x H x 3] is provided, with B being the batch size. Default is True. + out_keys (sequence of NestedKey, optional): destination keys. Defaults + to ``in_keys`` if not provided. 
""" @@ -44,16 +49,18 @@ def __init__( self, logger: Logger, tag: str, - in_keys: Optional[Sequence[str]] = None, + in_keys: Optional[Sequence[NestedKey]] = None, skip: int = 2, center_crop: Optional[int] = None, make_grid: bool = True, + out_keys: Optional[Sequence[NestedKey]] = None, **kwargs, ) -> None: if in_keys is None: in_keys = ["pixels"] - - super().__init__(in_keys=in_keys) + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) video_kwargs = {"fps": 6} video_kwargs.update(kwargs) self.video_kwargs = video_kwargs diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index f14d8683d90..418dd638269 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -15,8 +15,8 @@ MultiSyncDataCollector, SyncDataCollector, ) -from torchrl.data import MultiStep -from torchrl.envs import ParallelEnv +from torchrl.data.postprocs import MultiStep +from torchrl.envs.batched_envs import ParallelEnv from torchrl.envs.common import EnvBase from torchrl.envs.utils import ExplorationType diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 620e09fced8..582dace8ab9 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -145,16 +145,6 @@ def make_env_transforms( if reward_scaling is not None: env.append_transform(RewardScaling(reward_loc, reward_scaling)) - double_to_float_list = [] - double_to_float_inv_list = [] - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_list += [ - "action", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision if not from_pixels: selected_keys = [ key @@ -187,22 +177,13 @@ def make_env_transforms( ) ) - double_to_float_list.append(out_key) - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) + env.append_transform(DoubleToFloat()) if hasattr(cfg, "catframes") and cfg.catframes: env.append_transform(CatFrames(N=cfg.catframes, in_keys=[out_key], dim=-1)) else: - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) + env.append_transform(DoubleToFloat()) if hasattr(cfg, "gSDE") and cfg.gSDE: env.append_transform( diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 4d9e0198839..3951aa88c32 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -17,9 +17,9 @@ UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import TensorDictPrimer, TransformedEnv from torchrl.envs.common import EnvBase from torchrl.envs.model_based.dreamer import DreamerEnv +from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( NoisyLinear, @@ -128,7 +128,7 @@ def make_dqn_actor( atoms = cfg.atoms if cfg.distributional else None linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear - action_spec = env_specs["input_spec", "_action_spec", "action"] + action_spec = env_specs["input_spec", "full_action_spec", "action"] if action_spec.domain != "discrete": raise ValueError( f"env {proof_environment} has an action domain " @@ -158,7 +158,9 @@ def make_dqn_actor( "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class}, } # automatically infer in key - (in_key,) 
= itertools.islice(env_specs["output_spec", "_observation_spec"], 1) + (in_key,) = itertools.islice( + env_specs["output_spec", "full_observation_spec"], 1 + ) actor_class = QValueActor actor_kwargs = {} @@ -167,7 +169,7 @@ def make_dqn_actor( # if action spec is modeled as categorical variable, we still need to have features equal # to the number of possible choices and also set categorical behavioural for actors. actor_kwargs.update({"action_space": "categorical"}) - out_features = env_specs["input_spec", "_action_spec", "action"].space.n + out_features = env_specs["input_spec", "full_action_spec", "action"].space.n else: out_features = action_spec.shape[0] @@ -373,8 +375,8 @@ def make_redq_model( dist_class = TanhNormal dist_kwargs = { - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, "tanh_loc": tanh_loc, } @@ -398,8 +400,8 @@ def make_redq_model( ) if action_spec.domain == "continuous": - min = action_spec.space.minimum - max = action_spec.space.maximum + min = action_spec.space.low + max = action_spec.space.high transform = SafeTanhTransform() if (min != -1).any() or (max != 1).any(): transform = d.ComposeTransform( diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 3bd283fd4b6..51228e66da1 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -366,7 +366,7 @@ # f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation}) # # The only extra-difficulty that is brought up here is to split our output in two -# equal parts and map the second to a scrictly positive space. +# equal parts and map the second to a strictly positive space. # # We design the policy in three steps: # diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py new file mode 100644 index 00000000000..c5ae154fcfd --- /dev/null +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -0,0 +1,793 @@ +# -*- coding: utf-8 -*- +""" +Multi-Agent Reinforcement Learning (PPO) with TorchRL Tutorial +=============================================================== +**Author**: `Matteo Bettini `_ + +This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to +solve a Multi-Agent Reinforcement Learning (MARL) problem. + +A code-only version of this tutorial is available in the +`TorchRL examples `__, +alongside other simple scripts for many MARL algorithms (QMIX, MADDPG, IQL). + +For ease of use, this tutorial will follow the general structure of the already available +`single agent PPO tutorial `__. +It is suggested but not mandatory to get familiar with that prior to starting this tutorial. + +In this tutorial, we will use the *Navigation* environment from +`VMAS `__, +a multi-robot simulator, also +based on PyTorch, that runs parallel batched simulation on device. + +In the *Navigation* environment, +we need to train multiple robots (spawned at random positions) +to navigate to their goals (also at random positions), while +using `LIDAR sensors `__ to avoid collisions among each other. + +.. 
figure:: https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/img/navigation.gif + :alt: Navigation + + Multi-agent *Navigation* scenario + +Key learnings: + +- How to create a multi-agent environment in TorchRL, how its specs work, and how it integrates with the library; +- How you use GPU vectorized environments in TorchRL; +- How to create different multi-agent network architectures in TorchRL (e.g., using parameter sharing, centralised critic) +- How we can use :class:`tensordict.TensorDict` to carry multi-agent data; +- How we can tie all the library components (collectors, modules, replay buffers, and losses) in a multi-agent MAPPO/IPPO training loop. + +""" + +###################################################################### +# If you are running this in Google Colab, make sure you install the following dependencies: +# +# .. code-block:: bash +# +# !pip3 install torchrl +# !pip3 install vmas==1.2.11 +# !pip3 install tqdm +# +# Proximal Policy Optimization (PPO) is a policy-gradient algorithm where a +# batch of data is being collected and directly consumed to train the policy to maximise +# the expected return given some proximality constraints. You can think of it +# as a sophisticated version of `REINFORCE `_, +# the foundational policy-optimization algorithm. For more information, see the +# `Proximal Policy Optimization Algorithms `_ paper. +# +# This type of algorithms is usually trained *on-policy*. This means that, at every learning iteration, we have a +# **sampling** and a **training** phase. In the **sampling** phase of iteration :math:`t`, rollouts are collected +# form agents' interactions in the environment using the current policies :math:`\mathbf{\pi}_t`. +# In the **training** phase, all the collected rollouts are immediately fed to the training process to perform +# backpropagation. This leads to updated policies which are then used again for sampling. +# The execution of this process in a loop constitutes *on-policy learning*. +# +# .. figure:: https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/img/on_policy_vmas.png +# :alt: On-policy learning +# +# On-policy learning +# +# +# In the training phase of the PPO algorithm, a *critic* is used to estimate the goodness of the actions +# taken by the policy. The critic learns to approximate the value (mean discounted return) of a specific state. +# The PPO loss then compares the actual return obtained by the policy to the one estimated by the critic to determine +# the advantage of the action taken and guide the policy optimization. +# +# In multi-agent settings, things are a bit different. We now have multiple policies :math:`\mathbf{\pi}`, +# one for each agent. Policies are typically local and decentralised. This means that +# the policy for a single agent will output an action for that agent based only on its observation. +# In the MARL literature, this is referred to as **decentralised execution**. +# On the other hand, different formulations exist for the critic, mainly: +# +# - In `MAPPO `_ the critic is centralised and takes as input the global state +# of the system. This can be a global observation or simply the concatenation of the agents' observation. MAPPO +# can be used in contexts where **centralised training** is performed as it needs access to global information. +# - In `IPPO `_ the critic takes as input just the observation of the respective agent, +# exactly like the policy. 
This allows **decentralised training** as both the critic and the policy will only need local +# information to compute their outputs. +# +# Centralised critics help overcome the non-stationary of multiple agents learning concurrently, but, +# on the other hand, they may be impacted by their large input space. +# In this tutorial, we will be able to train both formulations, and we will also discuss how +# parameter-sharing (the practice of sharing the network parameters across the agents) impacts each. +# +# This tutorial is structured as follows: +# +# 1. First, we will define a set of hyperparameters we will be using. +# +# 2. Next, we will create a vectorized multi-agent environment, using TorchRL's +# wrapper for the VMAS simulator. +# +# 3. Next, we will design the policy and the critic networks, discussing the impact of the various choices on +# parameter sharing and critic centralisation. +# +# 4. Next, we will create the sampling collector and the replay buffer. +# +# 5. Finally, we will run our training loop and analyse the results. +# +# If you are running this in Colab or in a machine with a GUI, you will also have the option +# to render and visualise your own trained policy prior and after training. +# +# Let's import our dependencies +# + +# Torch +import torch + +# Tensordict modules +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor + +# Data collection +from torchrl.collectors import SyncDataCollector +from torchrl.data.replay_buffers import ReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage + +# Env +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import check_env_specs + +# Multi-agent network +from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal + +# Loss +from torchrl.objectives import ClipPPOLoss, ValueEstimators + +# Utils +torch.manual_seed(0) +from matplotlib import pyplot as plt +from tqdm import tqdm + + +###################################################################### +# Define Hyperparameters +# ---------------------- +# +# We set the hyperparameters for our tutorial. +# Depending on the resources +# available, one may choose to execute the policy and the simulator on GPU or on another +# device. +# You can tune some of these values to adjust the computational requirements. +# + +# Devices +device = "cpu" if not torch.has_cuda else "cuda:0" # The divice where learning is run +vmas_device = device # The device where the simulator is run (VMAS can run on GPU) + +# Sampling +frames_per_batch = 6_000 # Number of team frames collected per training iteration +n_iters = 10 # Number of sampling and training iterations +total_frames = frames_per_batch * n_iters + +# Training +num_epochs = 30 # Number of optimization steps per training iteration +minibatch_size = 400 # Size of the mini-batches in each optimization step +lr = 3e-4 # Learning rate +max_grad_norm = 1.0 # Maximum norm for the gradients + +# PPO +clip_epsilon = 0.2 # clip value for PPO loss +gamma = 0.9 # discount factor +lmbda = 0.9 # lambda for generalised advantage estimation +entropy_eps = 1e-4 # coefficient of the entropy term in the PPO loss + +###################################################################### +# Environment +# ----------- +# +# Multi-agent environments simulate multiple agents interacting with the world. 
+# TorchRL API allows integrating various types of multi-agent environment flavours. +# Some examples include environments with shared or individual agent rewards, done flags, and observations. +# For more information on how the multi-agent environments API works in TorchRL, you can check out the dedicated +# `doc section `_. +# +# The VMAS simulator, in particular, models agents with individual rewards, info, observations, and actions, but +# with a collective done flag. +# Furthermore, it uses *vectorization* to perform simulation in a batch. +# This means that all its state and physics +# are PyTorch tensors with a first dimension representing the number of parallel environments in a batch. +# This allows leveraging the Single Instruction Multiple Data (SIMD) paradigm of GPUs and significantly +# speed up parallel computation by leveraging parallelisation in GPU warps. It also means +# that, when using it in TorchRL, both simulation and training can be run on-device, without ever passing +# data to the CPU. +# +# The multi-agent task we will solve today is *Navigation* (see animated figure above). +# In *Navigation*, randomly spawned agents +# (circles with surrounding dots) need to navigate +# to randomly spawned goals (smaller circles). +# Agents need to use LIDARs (dots around them) to +# avoid colliding into each other. +# Agents act in a 2D continuous world with drag and elastic collisions. +# Their actions are 2D continuous forces which determine their acceleration. +# The reward is composed of three terms: a collision penalisation, a reward based on the distance to the goal, and a +# final shared reward given when all agents reach their goal. +# The distance-based term is computed as the difference in the relative distance +# between an agent and its goal over two consecutive timesteps. +# Each agent observes its position, +# velocity, lidar readings, and relative position to its goal. +# +# We will now instantiate the environment. +# For this tutorial, we will limit the episodes to ``max_steps``, after which the done flag is set. This is +# functionality is already provided in the VMAS simulator but the TorchRL :class:`~.envs.transforms.StepCount` +# transform could alternatively be used. +# We will also use ``num_vmas_envs`` vectorized environments, to leverage batch simulation. +# +# + +max_steps = 100 # Episode steps before done +num_vmas_envs = ( + frames_per_batch // max_steps +) # Number of vectorized envs. frames_per_batch should be divisible by this number +scenario_name = "navigation" +n_agents = 3 + +env = VmasEnv( + scenario=scenario_name, + num_envs=num_vmas_envs, + continuous_actions=True, # VMAS supports both continuous and discrete actions + max_steps=max_steps, + device=vmas_device, + # Scenario kwargs + n_agents=n_agents, # These are custom kwargs that change for each VMAS scenario, see the VMAS repo to know more. +) + +###################################################################### +# The environment is not only defined by its simulator and transforms, but also +# by a series of metadata that describe what can be expected during its +# execution. +# For efficiency purposes, TorchRL is quite stringent when it comes to +# environment specs, but you can easily check that your environment specs are +# adequate. +# In our example, the :class:`~.envs.libs.vmas.VmasEnv` takes care of setting the proper specs for your env so +# you should not have to care about this. 
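As a quick illustration of the batched simulation described above, the environment object itself carries the vectorized batch dimension (a sketch; with the hyperparameters of this tutorial, ``num_vmas_envs = 6000 // 100 = 60``):

print(env.batch_size)  # expected: torch.Size([num_vmas_envs]), i.e. torch.Size([60])
print(env.device)      # the simulation device (vmas_device)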
+# +# There are four specs to look at: +# +# - ``action_spec`` defines the action space; +# - ``reward_spec`` defines the reward domain; +# - ``done_spec`` defines the done domain; +# - ``observation_spec`` which defines the domain of all other outputs from environmnet steps; +# +# + +print("action_spec:", env.full_action_spec) +print("reward_spec:", env.full_reward_spec) +print("done_spec:", env.full_done_spec) +print("observation_spec:", env.observation_spec) + +###################################################################### +# Using the commands just shown we can access the domain of each value. +# Doing this we can see that all specs apart from done have a leading shape ``(num_vmas_envs, n_agents)``. +# This represents the fact that those values will be present for each agent in each individual environment. +# The done spec, on the other hand, has leading shape ``num_vmas_envs``, representing that done is shared among +# agents. +# +# TorchRL has a way to keep track of which MARL specs are shared and which are not. +# In fact, specs that have the additional agent dimension +# (i.e., they vary for each agent) will be contained in a inner "agents" key. +# +# As you can see the reward and action spec present the "agent" key, +# meaning that entries in tensordicts belonging to those specs will be nested in an "agents" tensordict, +# grouping all per-agent values. +# +# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the +# respective keys, and +# we will immediately understand which are per-agent and which shared. +# This info will be useful in order to tell all other TorchRL components where to find each value +# + +print("action_keys:", env.action_keys) +print("reward_keys:", env.reward_keys) +print("done_keys:", env.done_keys) + + +###################################################################### +# Transforms +# ~~~~~~~~~~ +# +# We can append any TorchRL transform we need to our enviornment. +# These will modify its input/output in some desired way. +# We stress that, in multi-agent contexts, it is paramount to provide explicitly the keys to modify. +# +# For example, in this case, we will instantiate a ``RewardSum`` transform which will sum rewards over the episode. +# We will tell this transform where to find the reward key and where to write the summed episode reward. +# The transformed environment will inherit +# the device and meta-data of the wrapped environment, and transform these depending on the sequence +# of transforms it contains. +# + + +env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), +) + + +###################################################################### +# the :func:`check_env_specs` function runs a small rollout and compares its output against the environment +# specs. If no error is raised, we can be confident that the specs are properly defined: +# +check_env_specs(env) + +###################################################################### +# Rollout +# ~~~~~~~ +# +# For fun, let's see what a simple random rollout looks like. You can +# call `env.rollout(n_steps)` and get an overview of what the environment inputs +# and outputs look like. Actions will automatically be drawn at random from the action spec +# domain. 
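To illustrate the last point, an action drawn directly from the spec already has the per-agent layout discussed above (a sketch; the trailing dimension of 2 corresponds to the 2D force actions of this scenario):

random_action = env.action_spec.rand()
print(random_action.shape)  # expected: (num_vmas_envs, n_agents, 2)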
+# +n_rollout_steps = 5 +rollout = env.rollout(n_rollout_steps) +print("rollout of three steps:", rollout) +print("Shape of the rollout TensorDict:", rollout.batch_size) +###################################################################### +# We can see that our rollout has ``batch_size`` of ``(num_vmas_envs, n_rollout_steps)``. +# This means that all the tensors in it will have those leading dimensions. +# +# Looking more in depth, we can see that the output tensordict can be divided in the following way: +# +# - *In the root* (accessible by running ``rollout.exclude("next")`` ) we will find all the keys that are available +# after a reset is called at the first timestep. We can see their evolution through the rollout steps by indexing +# the ``n_rollout_steps`` dimension. Among these keys, we will find the ones that are different for each agent +# in the ``rollout["agents"]`` tensordict, which will have batch size ``(num_vmas_envs, n_rollout_steps, n_agents)`` +# signifying that it is storing the additional agent dimension. The ones outside this agent tensordict +# will be the shared ones (in this case only done). +# - *In the next* (accessible by running ``rollout.get("next")`` ). We will find the same structure as the root, +# but for keys that are available only after a step. +# +# In TorchRL the convention is that done and observations will be present in both root and next (as these are +# available both at reset time and after a step). Action will only be available in root (as there is no action +# resulting from a step) and reward will only be available in next (as there is no reward at reset time). +# This structure follows the one in **Reinforcement Learning: An Introduction (Sutton and Barto)** where root represents data at time :math:`t` and +# next represents data at time :math:`t+1` of a world step. +# +# +# Render a random rollout +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# If you are on Google Colab, or on a machine with OpenGL and a GUI, you can actually render a random rollout. +# This will give you an idea of what a random policy will achieve in this task, in order to compare it +# with the policy you will train yourself! +# +# To render a rollout, follow the instructions in the *Render* section at the end of this tutorial +# and just remove the line ``policy=policy`` from ``env.rollout()`` . +# +# +# Policy +# ------ +# +# PPO utilises a stochastic policy to handle exploration. This means that our +# neural network will have to output the parameters of a distribution, rather +# than a single value corresponding to the action taken. +# +# As the data is continuous, we use a Tanh-Normal distribution to respect the +# action space boundaries. TorchRL provides such distribution, and the only +# thing we need to care about is to build a neural network that outputs the +# right number of parameters. +# +# In this case, each agent's action will be represented by a 2-dimensional independent normal distribution. +# For this, our neural network will have to output a mean and a standard deviation for each action. +# Each agent will thus have ``2 * n_actions_per_agents`` outputs. +# +# Another important decision we need to make is whether we want our agents to **share the policy parameters**. +# On the one hand, sharing parameters means that they will all share the same policy, which will allow them to benefit from +# each other's experiences. This will also result in faster training. 
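Returning briefly to the root/``"next"`` convention described above, here is a hedged sketch of how the rollout collected earlier is indexed (key names follow the VMAS setup of this tutorial):

action = rollout["agents", "action"]          # only available in the root
reward = rollout["next", "agents", "reward"]  # only available under "next"
done_root = rollout["done"]                   # shared flag, present in the root...
done_next = rollout["next", "done"]           # ...and under "next"
print(action.shape, reward.shape)  # expected leading dims: (num_vmas_envs, n_rollout_steps, n_agents)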
+# On the other hand, it will make them behaviorally *homogenous*, as they will in fact share the same model. +# For this example, we will enable sharing as we do not mind the homogeneity and can benefit from the computational +# speed, but it is important to always think about this decision in your own problems! +# +# We design the policy in three steps. +# +# **First**: define a neural network ``n_obs_per_agent`` -> ``2 * n_actions_per_agents`` +# +# For this we use the ``MultiAgentMLP``, a TorchRL module made exactly for +# multiple agents, with much customisation available. +# + +share_parameters_policy = True + +policy_net = torch.nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[ + -1 + ], # n_obs_per_agent + n_agent_outputs=2 * env.action_spec.shape[-1], # 2 * n_actions_per_agents + n_agents=env.n_agents, + centralised=False, # the policies are decentralised (ie each agent will act from its observation) + share_params=share_parameters_policy, + device=device, + depth=2, + num_cells=256, + activation_class=torch.nn.Tanh, + ), + NormalParamExtractor(), # this will just separate the last dimension into two outputs: a loc and a non-negative scale +) + +###################################################################### +# **Second**: wrap the neural network in a :class:`TensorDictModule` +# +# This is simply a module that will read the ``in_keys`` from a tensordict, feed them to the +# neural networks, and write the +# outputs in-place at the ``out_keys``. +# +# Note that we use ``("agents", ...)`` keys as these keys are denoting data with the +# additional ``n_agents`` dimension. +# +policy_module = TensorDictModule( + policy_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], +) + +###################################################################### +# **Third**: wrap the :class:`TensorDictModule` in a :class:`ProbabilisticActor` +# +# We now need to build a distribution out of the location and scale of our +# normal distribution. To do so, we instruct the :class:`ProbabilisticActor` +# class to build a :class:`TanhNormal` out of the location and scale +# parameters. We also provide the minimum and maximum values of this +# distribution, which we gather from the environment specs. +# +# The name of the ``in_keys`` (and hence the name of the ``out_keys`` from +# the :class:`TensorDictModule` above) has to end with the +# :class:`TanhNormal` distribution constructor keyword arguments (loc and scale). +# + +policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "loc"), ("agents", "scale")], + out_keys=[env.action_key], + distribution_class=TanhNormal, + distribution_kwargs={ + "min": env.unbatched_action_spec[env.action_key].space.low, + "max": env.unbatched_action_spec[env.action_key].space.high, + }, + return_log_prob=True, + log_prob_key=("agents", "sample_log_prob"), +) # we'll need the log-prob for the PPO loss + +###################################################################### +# Critic network +# -------------- +# +# The critic network is a crucial component of the PPO algorithm, even though it +# isn't used at sampling time. This module will read the observations and +# return the corresponding value estimates. +# +# As before, one should think carefully about the decision of **sharing the critic parameters**. 
+# In general, parameter sharing will grant faster training convergence, but there are a few important +# considerations to be made: +# +# - Sharing is not recommended when agents have different reward functions, as the critics will need to learn +# to assign different values to the same state (e.g., in mixed cooperative-competitive settings). +# - In decentralised training settings, sharing cannot be performed without additional infrastructure to +# synchronise parameters. +# +# In all other cases where the reward function (to be differentiated from the reward) is the same for all agents +# (as in the current scenario), +# sharing can provide improved performance. This can come at the cost of homogeneity in the agent strategies. +# In general, the best way to know which choice is preferable is to quickly experiment both options. +# +# Here is also where we have to choose between **MAPPO and IPPO**: +# +# - With MAPPO, we will obtain a central critic with full-observability +# (i.e., it will take all the concatenated agent observations as input). +# We can do this because we are in a simulator +# and training is centralised. +# - With IPPO, we will have a local decentralised critic, just like the policy. +# +# In any case, the critic output will have shape ``(..., n_agents, 1)``. +# If the critic is centralised and shared, +# all the values along the ``n_agents`` dimension will be identical. +# + +share_parameters_critic = True +mappo = True # IPPO if False + +critic_net = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=1, # 1 value per agent + n_agents=env.n_agents, + centralised=mappo, + share_params=share_parameters_critic, + device=device, + depth=2, + num_cells=256, + activation_class=torch.nn.Tanh, +) + +critic = TensorDictModule( + module=critic_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "state_value")], +) + +###################################################################### +# Let us try our policy and critic modules. As pointed earlier, the usage of +# :class:`TensorDictModule` makes it possible to directly read the output +# of the environment to run these modules, as they know what information to read +# and where to write it: +# +# **From this point on, the multi-agent-specific components have been instantiated, and we will simply use the same +# components as in single-agent learning. Isn't this fantastic?** +# +print("Running policy:", policy(env.reset())) +print("Running value:", critic(env.reset())) + +###################################################################### +# Data collector +# -------------- +# +# TorchRL provides a set of data collector classes. Briefly, these +# classes execute three operations: reset an environment, compute an action +# using the policy and the latest observation, execute a step in the environment, and repeat +# the last two steps until the environment signals a stop (or reaches a done +# state). +# +# We will use the simplest possible data collector, which has the same output as an environment rollout, +# with the only difference that it will auto reset done states until the desired frames are collected. +# +collector = SyncDataCollector( + env, + policy, + device=vmas_device, + storing_device=device, + frames_per_batch=frames_per_batch, + total_frames=total_frames, +) + +###################################################################### +# Replay buffer +# ------------- +# +# Replay buffers are a common building piece of off-policy RL algorithms. 
+# In on-policy contexts, a replay buffer is refilled every time a batch of +# data is collected, and its data is repeatedly consumed for a certain number +# of epochs. +# +# Using a replay buffer for PPO is not mandatory and we could simply +# use the collected data online, but using these classes +# makes it easy for us to build the inner training loop in a reproducible way. +# + +replay_buffer = ReplayBuffer( + storage=LazyTensorStorage( + frames_per_batch, device=device + ), # We store the frames_per_batch collected at each iteration + sampler=SamplerWithoutReplacement(), + batch_size=minibatch_size, # We will sample minibatches of this size +) + +###################################################################### +# Loss function +# ------------- +# +# The PPO loss can be directly imported from TorchRL for convenience using the +# :class:`~.objectives.ClipPPOLoss` class. This is the easiest way of utilising PPO: +# it hides away the mathematical operations of PPO and the control flow that +# goes with it. +# +# PPO requires some "advantage estimation" to be computed. In short, an advantage +# is a value that reflects an expectancy over the return value while dealing with +# the bias / variance tradeoff. +# To compute the advantage, one just needs to (1) build the advantage module, which +# utilises our value operator, and (2) pass each batch of data through it before each +# epoch. +# The GAE module will update the input :class:`TensorDict` with new ``"advantage"`` and +# ``"value_target"`` entries. +# The ``"value_target"`` is a gradient-free tensor that represents the empirical +# value that the value network should represent with the input observation. +# Both of these will be used by :class:`ClipPPOLoss` to +# return the policy and value losses. +# + +loss_module = ClipPPOLoss( + actor=policy, + critic=critic, + clip_epsilon=clip_epsilon, + entropy_coef=entropy_eps, + normalize_advantage=False, # Important to avoid normalizing across the agent dimension +) +loss_module.set_keys( # We have to tell the loss where to find the keys + reward=env.reward_key, + action=env.action_key, + sample_log_prob=("agents", "sample_log_prob"), + value=("agents", "state_value"), + # These last 2 keys will be expanded to match the reward shape + done=("agents", "done"), + terminated=("agents", "terminated"), +) + + +loss_module.make_value_estimator( + ValueEstimators.GAE, gamma=gamma, lmbda=lmbda +) # We build GAE +GAE = loss_module.value_estimator + +optim = torch.optim.Adam(loss_module.parameters(), lr) + +###################################################################### +# Training loop +# ------------- +# We now have all the pieces needed to code our training loop. 
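Before spelling out the loop, a worked check of the optimization budget implied by the hyperparameters above (plain arithmetic, not part of the tutorial code):

steps_per_epoch = frames_per_batch // minibatch_size  # 6000 // 400 = 15
optim_steps_per_iter = num_epochs * steps_per_epoch   # 30 * 15 = 450
total_optim_steps = n_iters * optim_steps_per_iter    # 10 * 450 = 4500
print(steps_per_epoch, optim_steps_per_iter, total_optim_steps)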
+# The steps include: +# +# * Collect data +# * Compute advantage +# * Loop over epochs +# * Loop over minibatches to compute loss values +# * Back propagate +# * Optimise +# * Repeat +# * Repeat +# * Repeat +# * Repeat +# +# + +pbar = tqdm(total=n_iters, desc="episode_reward_mean = 0") + +episode_reward_mean_list = [] +for tensordict_data in collector: + tensordict_data.set( + ("next", "agents", "done"), + tensordict_data.get(("next", "done")) + .unsqueeze(-1) + .expand(tensordict_data.get_item_shape(("next", env.reward_key))), + ) + tensordict_data.set( + ("next", "agents", "terminated"), + tensordict_data.get(("next", "terminated")) + .unsqueeze(-1) + .expand(tensordict_data.get_item_shape(("next", env.reward_key))), + ) + # We need to expand the done and terminated to match the reward shape (this is expected by the value estimator) + + with torch.no_grad(): + GAE( + tensordict_data, + params=loss_module.critic_params, + target_params=loss_module.target_critic_params, + ) # Compute GAE and add it to the data + + data_view = tensordict_data.reshape(-1) # Flatten the batch size to shuffle data + replay_buffer.extend(data_view) + + for _ in range(num_epochs): + for _ in range(frames_per_batch // minibatch_size): + subdata = replay_buffer.sample() + loss_vals = loss_module(subdata) + + loss_value = ( + loss_vals["loss_objective"] + + loss_vals["loss_critic"] + + loss_vals["loss_entropy"] + ) + + loss_value.backward() + + torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_grad_norm + ) # Optional + + optim.step() + optim.zero_grad() + + collector.update_policy_weights_() + + # Logging + done = tensordict_data.get(("next", "agents", "done")) + episode_reward_mean = ( + tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item() + ) + episode_reward_mean_list.append(episode_reward_mean) + pbar.set_description(f"episode_reward_mean = {episode_reward_mean}", refresh=False) + pbar.update() + +###################################################################### +# Results +# ------- +# +# Let's plot the mean reward obtained per episode +# +# To make training last longer, increase the ``n_iters`` hyperparameter. +# +plt.plot(episode_reward_mean_list) +plt.xlabel("Training iterations") +plt.ylabel("Reward") +plt.title("Episode reward mean") +plt.show() + +###################################################################### +# Render +# ------ +# +# If you are running this in a machine with GUI, you can render the trained policy by running: +# +# .. code-block:: python +# +# with torch.no_grad(): +# env.rollout( +# max_steps=max_steps, +# policy=policy, +# callback=lambda env, _: env.render(), +# auto_cast_to_device=True, +# break_when_any_done=False, +# ) +# +# If you are running this in Google Colab, you can render the trained policy by running: +# +# .. code-block:: bash +# +# !apt-get update +# !apt-get install -y x11-utils +# !apt-get install -y xvfb +# !pip install pyvirtualdisplay +# +# .. 
code-block:: python +# +# import pyvirtualdisplay +# display = pyvirtualdisplay.Display(visible=False, size=(1400, 900)) +# display.start() +# from PIL import Image +# +# def rendering_callback(env, td): +# env.frames.append(Image.fromarray(env.render(mode="rgb_array"))) +# env.frames = [] +# with torch.no_grad(): +# env.rollout( +# max_steps=max_steps, +# policy=policy, +# callback=rendering_callback, +# auto_cast_to_device=True, +# break_when_any_done=False, +# ) +# env.frames[0].save( +# f"{scenario_name}.gif", +# save_all=True, +# append_images=env.frames[1:], +# duration=3, +# loop=0, +# ) +# +# from IPython.display import Image +# Image(open(f"{scenario_name}.gif", "rb").read()) +# + + +###################################################################### +# Conclusion and next steps +# ------------------------- +# +# In this tutorial, we have seen: +# +# - How to create a multi-agent environment in TorchRL, how its specs work, and how it integrates with the library; +# - How to use GPU-vectorized environments in TorchRL; +# - How to create different multi-agent network architectures in TorchRL (e.g., using parameter sharing, centralised critic); +# - How we can use :class:`tensordict.TensorDict` to carry multi-agent data; +# - How we can tie all the library components (collectors, modules, replay buffers, and losses) together in a multi-agent MAPPO/IPPO training loop. +# +# Now that you are proficient with multi-agent PPO, you can check out all +# `TorchRL multi-agent examples `__. +# These are code-only scripts of many popular MARL algorithms, such as the ones seen in this tutorial +# as well as QMIX, MADDPG, IQL, and many more! +# +# If you are interested in creating or wrapping your own multi-agent environments in TorchRL, +# you can check out the dedicated +# `doc section `_. +# +# Finally, you can modify the parameters of this tutorial to try many other configurations and scenarios +# to become a MARL master. +# Here are a few videos of some possible scenarios you can try in VMAS. +# +# .. figure:: https://github.com/matteobettini/vmas-media/blob/main/media/VMAS_scenarios.gif?raw=true +# :alt: VMAS scenarios +# +# Scenarios available in `VMAS `__ +# diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 17f41430217..4fa160ff12f 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -28,7 +28,7 @@ - Transforming your environment inputs and outputs, and writing your own transforms; - How to use :class:`tensordict.TensorDict` to carry arbitrary data structures - from sep to step. + from step to step. 
In the process, we will touch three crucial components of TorchRL: @@ -241,16 +241,13 @@ def _step(tensordict): new_th = th + new_thdot * dt reward = -costs.view(*tensordict.shape, 1) done = torch.zeros_like(reward, dtype=torch.bool) - # The output must be written in a ``"next"`` entry out = TensorDict( { - "next": { - "th": new_th, - "thdot": new_thdot, - "params": tensordict["params"], - "reward": reward, - "done": done, - } + "th": new_th, + "thdot": new_thdot, + "params": tensordict["params"], + "reward": reward, + "done": done, }, tensordict.shape, ) @@ -389,14 +386,14 @@ def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] self.observation_spec = CompositeSpec( th=BoundedTensorSpec( - minimum=-torch.pi, - maximum=torch.pi, + low=-torch.pi, + high=torch.pi, shape=(), dtype=torch.float32, ), thdot=BoundedTensorSpec( - minimum=-td_params["params", "max_speed"], - maximum=td_params["params", "max_speed"], + low=-td_params["params", "max_speed"], + high=td_params["params", "max_speed"], shape=(), dtype=torch.float32, ), @@ -411,8 +408,8 @@ def _make_spec(self, td_params): # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` will be called supported self.action_spec = BoundedTensorSpec( - minimum=-td_params["params", "max_torque"], - maximum=td_params["params", "max_torque"], + low=-td_params["params", "max_torque"], + high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) @@ -440,7 +437,7 @@ def make_composite_from_td(td): # Reproducible experiments: seeding # --------------------------------- # -# Seeding an environment is a commong operation when initializing an experiment. +# Seeding an environment is a common operation when initializing an experiment. # :func:`EnvBase._set_seed` only goal is to set the seed of the contained # simulator. If possible, this operation should not call `reset()` or interact # with the environment execution. 
The parent :func:`EnvBase.set_seed` method @@ -661,8 +658,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: @_apply_to_composite def transform_observation_spec(self, observation_spec): return BoundedTensorSpec( - minimum=-1, - maximum=1, + low=-1, + high=1, shape=observation_spec.shape, dtype=observation_spec.dtype, device=observation_spec.device, @@ -679,8 +676,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: @_apply_to_composite def transform_observation_spec(self, observation_spec): return BoundedTensorSpec( - minimum=-1, - maximum=1, + low=-1, + high=1, shape=observation_spec.shape, dtype=observation_spec.dtype, device=observation_spec.device, diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index a4e9bf4930b..ccb1c9d4ea7 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -45,7 +45,7 @@ # The list of available environment can be accessed through this command: # -GymEnv.available_envs[:10] +list(GymEnv.available_envs)[:10] ############################################################################### # Env Specs @@ -60,7 +60,6 @@ print("Env observation_spec: \n", env.observation_spec) print("Env action_spec: \n", env.action_spec) print("Env reward_spec: \n", env.reward_spec) -print("Env done_spec: \n", env.done_spec) ############################################################################### # Those spec come with a series of useful tools: one can assert whether a @@ -76,6 +75,18 @@ print("random action: \n", env.action_spec.rand()) +############################################################################### +# Out of these specs, the ``done_spec`` deserves special attention. In TorchRL, +# all environments write end-of-trajectory signals of at least two types: +# ``"terminated"`` (indicating that the Markov Decision Process has reached +# a final state - the __episode__ is finished) and ``"done"``, indicating that +# this is the last step of a __trajectory__ (but not necessarily the end of +# the task). In general, a ``"done"`` entry that is ``True`` when ``"terminated"`` +# is ``False`` is caused by a ``"truncated"`` signal. Gym environments account for +# these three signals: + +print(env.done_spec) + ############################################################################### # Envs are also packed with an ``env.state_spec`` attribute of type # ``CompositeSpec`` which contains all the specs that are inputs to the env diff --git a/version.txt b/version.txt index 17e51c385ea..0ea3a944b39 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.1 +0.2.0