diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5e878991b8e..7246e3a945c 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -375,7 +375,7 @@ jobs:
       image: ubuntu-2004-cuda-11.4:202110-01
       resource_class: gpu.nvidia.medium
     environment:
-      image_name: "pytorch/manylinux-cuda117"
+      image_name: "nvidia/cudagl:11.4.0-base"
       TAR_OPTIONS: --no-same-owner
       PYTHON_VERSION: << parameters.python_version >>
       CU_VERSION: << parameters.cu_version >>
diff --git a/.circleci/unittest/linux_examples/scripts/install.sh b/.circleci/unittest/linux_examples/scripts/install.sh
index a7c9bb93976..3205f7496ee 100755
--- a/.circleci/unittest/linux_examples/scripts/install.sh
+++ b/.circleci/unittest/linux_examples/scripts/install.sh
@@ -4,6 +4,8 @@ unset PYTORCH_VERSION
 # For unittest, nightly PyTorch is used as the following section,
 # so no need to set PYTORCH_VERSION.
 # In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.
+apt-get update && apt-get install -y git wget gcc g++
+#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev
 
 set -e
diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh
index cc57b730be8..cfb32bc80bd 100755
--- a/.circleci/unittest/linux_examples/scripts/run_test.sh
+++ b/.circleci/unittest/linux_examples/scripts/run_test.sh
@@ -8,6 +8,8 @@
 
 set -e
 
+apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 wget freeglut3 freeglut3-dev
+
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env
 
@@ -27,19 +29,75 @@ export MKL_THREADING_LAYER=GNU
 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 20
 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20
+# ========================================================================================
+# DDPG
+# ----
+#
+# Modalities:
+# ^^^^^^^^^^^
+#
+# pixels on/off
+# Batched on/off
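+#
+# These runs use the structured Hydra config in examples/ddpg/config.yaml
+# (collector.*, optim.*, recorder.*, replay_buffer.*, env.*, logger.*) and are
+# kept to a handful of frames so that they only smoke-test the training loop.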
+#
 # With batched environments
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
-  total_frames=48 \
-  init_random_frames=10 \
-  batch_size=10 \
-  frames_per_batch=16 \
-  num_workers=4 \
-  env_per_collector=2 \
-  collector_devices=cuda:0 \
-  optim_steps_per_batch=1 \
-  record_video=True \
-  record_frames=4 \
-  buffer_size=120
+  collector.total_frames=48 \
+  collector.init_random_frames=10 \
+  collector.frames_per_batch=16 \
+  collector.num_collectors=4 \
+  collector.collector_devices=cuda:0 \
+  env.num_envs=2 \
+  optim.batch_size=10 \
+  optim.optim_steps_per_batch=1 \
+  recorder.video=True \
+  recorder.frames=4 \
+  replay_buffer.capacity=120 \
+  env.from_pixels=False \
+  logger.backend=csv
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
+  collector.total_frames=48 \
+  collector.init_random_frames=10 \
+  collector.frames_per_batch=16 \
+  collector.num_collectors=4 \
+  collector.collector_devices=cuda:0 \
+  env.num_envs=2 \
+  optim.batch_size=10 \
+  optim.optim_steps_per_batch=1 \
+  recorder.video=True \
+  recorder.frames=4 \
+  replay_buffer.capacity=120 \
+  env.from_pixels=True \
+  logger.backend=csv
+# With single envs
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
+  collector.total_frames=48 \
+  collector.init_random_frames=10 \
+  collector.frames_per_batch=16 \
+  collector.num_collectors=4 \
+  collector.collector_devices=cuda:0 \
+  env.num_envs=1 \
+  optim.batch_size=10 \
+  optim.optim_steps_per_batch=1 \
+  recorder.video=True \
+  recorder.frames=4 \
+  replay_buffer.capacity=120 \
+  env.from_pixels=False \
+  logger.backend=csv
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
+  collector.total_frames=48 \
+  collector.init_random_frames=10 \
+  collector.frames_per_batch=16 \
+  collector.num_collectors=4 \
+  collector.collector_devices=cuda:0 \
+  env.num_envs=1 \
+  optim.batch_size=10 \
+  optim.optim_steps_per_batch=1 \
+  recorder.video=True \
+  recorder.frames=4 \
+  replay_buffer.capacity=120 \
+  env.from_pixels=True \
+  logger.backend=csv
+
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
   total_frames=48 \
   batch_size=10 \
@@ -112,19 +170,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea
   buffer_size=120 \
   rssm_hidden_dim=17
 
-# With single envs
-python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
-  total_frames=48 \
-  init_random_frames=10 \
-  batch_size=10 \
-  frames_per_batch=16 \
-  num_workers=2 \
-  env_per_collector=1 \
-  collector_devices=cuda:0 \
-  optim_steps_per_batch=1 \
-  record_video=True \
-  record_frames=4 \
-  buffer_size=120
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
   total_frames=48 \
   batch_size=10 \
diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh
index c79f25a6979..007d57572a4 100755
--- a/.circleci/unittest/linux_examples/scripts/setup_env.sh
+++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh
@@ -9,6 +9,8 @@ set -e
 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
 # Avoid error: "fatal: unsafe repository"
+apt-get update && apt-get install -y git wget gcc g++
+
 git config --global --add safe.directory '*'
 root_dir="$(git rev-parse --show-toplevel)"
 conda_dir="${root_dir}/conda"
@@ -71,18 +73,6 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
   MUJOCO_GL=$PRIVATE_MUJOCO_GL \
   PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL
 
-# Software rendering requires GLX and OSMesa.
-if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then
-  yum makecache
-  yum install -y glfw
-  yum install -y glew
-  yum install -y mesa-libGL
-  yum install -y mesa-libGL-devel
-  yum install -y mesa-libOSMesa-devel
-  yum -y install egl-utils
-  yum -y install freeglut
-fi
-
 pip install pip --upgrade
 
 conda env update --file "${this_dir}/environment.yml" --prune
@@ -90,29 +80,5 @@ conda env update --file "${this_dir}/environment.yml" --prune
 
 conda deactivate
 conda activate "${env_dir}"
-if [[ $OSTYPE != 'darwin'* ]]; then
-  # install ale-py: manylinux names are broken for CentOS so we need to manually download and
-  # rename them
-  PY_VERSION=$(python --version)
-  if [[ $PY_VERSION == *"3.7"* ]]; then
-    wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-  elif [[ $PY_VERSION == *"3.8"* ]]; then
-    wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-  elif [[ $PY_VERSION == *"3.9"* ]]; then
-    wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-  elif [[ $PY_VERSION == *"3.10"* ]]; then
-    wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
-    mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-    rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-  fi
-  pip install "gymnasium[atari,accept-rom-license]"
-else
-  pip install "gymnasium[atari,accept-rom-license]"
-fi
+pip install ale-py
+pip install "gymnasium[atari,accept-rom-license]"
diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml
index 5ad3912c0ef..b4091a2b70d 100644
--- a/examples/ddpg/config.yaml
+++ b/examples/ddpg/config.yaml
@@ -1,36 +1,65 @@
-env_name: HalfCheetah-v4
-env_task: ""
-env_library: gym
-async_collection: 1
-record_video: 0
-normalize_rewards_online: 1
-normalize_rewards_online_scale: 5
-frame_skip: 1
-frames_per_batch: 1024
-optim_steps_per_batch: 128
-batch_size: 256
-total_frames: 1000000
-prb: 1
-lr: 3e-4
-ou_exploration: 1
-multi_step: 1
-init_random_frames: 25000
-activation: elu
-gSDE: 0
-from_pixels: 0
-#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
-collector_devices: [cpu,cpu,cpu,cpu]
-env_per_collector: 8
-num_workers: 32
-lr_scheduler: ""
-value_network_update_interval: 200
-record_interval: 10
-max_frames_per_traj: -1
-weight_decay: 0.0
-annealing_frames: 1000000
-init_env_steps: 10000
-record_frames: 10000
-loss_function: smooth_l1
-batch_transform: 1
-buffer_prefetch: 64
-norm_stats: 1
+# task and env
+env:
+  env_name: HalfCheetah-v4
+  env_task: ""
+  env_library: gym
+  normalize_rewards_online: 1
+  normalize_rewards_online_scale: 5
+  frame_skip: 1
+  norm_stats: 1
+  num_envs: 4
+  n_samples_stats: 1000
+  noop: 1
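+  # left empty, reward_scaling falls back on a task-specific default
+  # (see DEFAULT_REWARD_SCALING in utils.py)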
+  reward_scaling:
+  from_pixels: False
+
+# collector
+collector:
+  async_collection: 1
+  frames_per_batch: 1024
+  total_frames: 1000000
+  multi_step: 3 # 0 to disable
+  init_random_frames: 25000
+  collector_devices: cpu # [cpu,cpu,cpu,cpu]
+  num_collectors: 4
+  max_frames_per_traj: -1
+
+# eval
+recorder:
+  video: True
+  interval: 10000 # record interval in frames
+  frames: 10000
+
+# logger
+logger:
+  backend: wandb
+  exp_name: ddpg_cheetah_gym
+
+# Buffer
+replay_buffer:
+  prb: 1
+  buffer_prefetch: 64
+  capacity: 1_000_000
+
+# Optim
+optim:
+  device: cpu
+  lr: 3e-4
+  weight_decay: 0.0
+  batch_size: 256
+  lr_scheduler: ""
+  value_network_update_interval: 200
+  optim_steps_per_batch: 8
+
+# Policy and model
+model:
+  ou_exploration: 1
+  annealing_frames: 1000000
+  noisy: False
+  activation: elu
+
+# loss
+loss:
+  loss_function: smooth_l1
+  gamma: 0.99
+  tau: 0.05
diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py
index aed849cd6b5..d68b984aa57 100644
--- a/examples/ddpg/ddpg.py
+++ b/examples/ddpg/ddpg.py
@@ -2,196 +2,104 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+"""DDPG Example.
 
-import dataclasses
+This is a self-contained example of a DDPG training script.
+
+It works across Gym and DM-control over a variety of tasks.
+
+Both state and pixel-based environments are supported.
+
+The helper functions are coded in the utils.py associated with this script.
+
+"""
 import hydra
-import torch.cuda
-from hydra.core.config_store import ConfigStore
-from torchrl.envs import EnvCreator, ParallelEnv
-from torchrl.envs.transforms import RewardScaling, TransformedEnv
-from torchrl.envs.utils import set_exploration_mode
-from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
-from torchrl.record import VideoRecorder
-from torchrl.record.loggers import generate_exp_name, get_logger
-from torchrl.trainers.helpers.collectors import (
-    make_collector_offpolicy,
-    OffPolicyCollectorConfig,
+import tqdm
+from torchrl.trainers.helpers.envs import correct_for_frame_skip
+
+from utils import (
+    get_stats,
+    make_collector,
+    make_ddpg_model,
+    make_logger,
+    make_loss,
+    make_optim,
+    make_policy,
+    make_recorder,
+    make_replay_buffer,
 )
-from torchrl.trainers.helpers.envs import (
-    correct_for_frame_skip,
-    EnvConfig,
-    initialize_observation_norm_transforms,
-    parallel_env_constructor,
-    retrieve_observation_norms_state_dict,
-    transformed_env_constructor,
-)
-from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss
-from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor
-from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
-from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-
-config_fields = [
-    (config_field.name, config_field.type, config_field)
-    for config_cls in (
-        TrainerConfig,
-        OffPolicyCollectorConfig,
-        EnvConfig,
-        LossConfig,
-        DDPGModelConfig,
-        LoggerConfig,
-        ReplayArgsConfig,
-    )
-    for config_field in dataclasses.fields(config_cls)
-]
-Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
-cs = ConfigStore.instance()
-cs.store(name="config", node=Config)
-
-DEFAULT_REWARD_SCALING = {
-    "Hopper-v1": 5,
-    "Walker2d-v1": 5,
-    "HalfCheetah-v1": 5,
-    "cheetah": 5,
-    "Ant-v2": 5,
-    "Humanoid-v2": 20,
-    "humanoid": 100,
-}
-
-
-@hydra.main(version_base=None, config_path=".", config_name="config")
+
+
+@hydra.main(config_path=".", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
     cfg = correct_for_frame_skip(cfg)
-
-    if not isinstance(cfg.reward_scaling, float):
-        cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0)
-
-    device = (
-        torch.device("cpu")
-        if torch.cuda.device_count() == 0
-        else torch.device("cuda:0")
-    )
-
-    exp_name = generate_exp_name("DDPG", cfg.exp_name)
-    logger = get_logger(
-        logger_type=cfg.logger, logger_name="ddpg_logging", experiment_name=exp_name
-    )
-    video_tag = exp_name if cfg.record_video else ""
-
-    key, init_env_steps, stats = None, None, None
-    if not cfg.vecnorm and cfg.norm_stats:
-        if not hasattr(cfg, "init_env_steps"):
-            raise AttributeError("init_env_steps missing from arguments.")
-        key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector")
-        init_env_steps = cfg.init_env_steps
-        stats = {"loc": None, "scale": None}
-    elif cfg.from_pixels:
-        stats = {"loc": 0.5, "scale": 0.5}
-
-    proof_env = transformed_env_constructor(
-        cfg=cfg,
-        stats=stats,
-        use_env_creator=False,
-    )()
-    initialize_observation_norm_transforms(
-        proof_environment=proof_env, num_iter=init_env_steps, key=key
-    )
-    _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]
-
-    model = make_ddpg_actor(
-        proof_env,
-        cfg=cfg,
-        device=device,
-    )
-    loss_module, target_net_updater = make_ddpg_loss(model, cfg)
-
-    actor_model_explore = model[0]
-    if cfg.ou_exploration:
-        if cfg.gSDE:
-            raise RuntimeError("gSDE and ou_exploration are incompatible")
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
-            actor_model_explore,
-            annealing_num_steps=cfg.annealing_frames,
-            sigma=cfg.ou_sigma,
-            theta=cfg.ou_theta,
-        ).to(device)
-        if device == torch.device("cpu"):
-            # mostly for debugging
-            actor_model_explore.share_memory()
-
-    if cfg.gSDE:
-        with torch.no_grad(), set_exploration_mode("random"):
-            # get dimensions to build the parallel env
-            proof_td = actor_model_explore(proof_env.reset().to(device))
-        action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
-        del proof_td
-    else:
-        action_dim_gsde, state_dim_gsde = None, None
-
-    proof_env.close()
-
-    create_env_fn = parallel_env_constructor(
-        cfg=cfg,
-        obs_norm_state_dict=obs_norm_state_dict,
-        action_dim_gsde=action_dim_gsde,
-        state_dim_gsde=state_dim_gsde,
-    )
-
-    collector = make_collector_offpolicy(
-        make_env=create_env_fn,
-        actor_model_explore=actor_model_explore,
-        cfg=cfg,
-        # make_env_kwargs=[
-        #     {"device": device} if device >= 0 else {}
-        #     for device in args.env_rendering_devices
-        # ],
-    )
-
-    replay_buffer = make_replay_buffer(device, cfg)
-
-    recorder = transformed_env_constructor(
-        cfg,
-        video_tag=video_tag,
-        norm_obs_only=True,
-        obs_norm_state_dict=obs_norm_state_dict,
-        logger=logger,
-        use_env_creator=False,
-    )()
-    if isinstance(create_env_fn, ParallelEnv):
-        raise NotImplementedError("This behaviour is deprecated")
-    elif isinstance(create_env_fn, EnvCreator):
-        recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
-    elif isinstance(create_env_fn, TransformedEnv):
-        recorder.transform = create_env_fn.transform.clone()
-    else:
-        raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
-    if logger is not None and video_tag:
-        recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
-
-    # reset reward scaling
-    for t in recorder.transform:
-        if isinstance(t, RewardScaling):
-            t.scale.fill_(1.0)
-            t.loc.fill_(0.0)
-
-    trainer = make_trainer(
-        collector,
-        loss_module,
-        recorder,
-        target_net_updater,
-        actor_model_explore,
-        replay_buffer,
-        logger,
-        cfg,
-    )
-
-    final_seed = collector.set_seed(cfg.seed)
-    print(f"init seed: {cfg.seed}, final seed: {final_seed}")
-
-    trainer.train()
-    return (logger.log_dir, trainer._log_dict)
+    model_device = cfg.optim.device
+
+    state_dict = get_stats(cfg.env)
+    logger = make_logger(cfg.logger)
+    replay_buffer = make_replay_buffer(cfg.replay_buffer)
+
+    actor_network, value_network = make_ddpg_model(cfg)
+    actor_network = actor_network.to(model_device)
+    value_network = value_network.to(model_device)
+
+    policy = make_policy(cfg.model, actor_network)
+    collector = make_collector(cfg, state_dict=state_dict, policy=policy)
+    loss, target_net_updater = make_loss(cfg.loss, actor_network, value_network)
+    optim = make_optim(cfg.optim, actor_network, value_network)
+    recorder = make_recorder(cfg, logger, policy)
+
+    optim_steps_per_batch = cfg.optim.optim_steps_per_batch
+    batch_size = cfg.optim.batch_size
+    init_random_frames = cfg.collector.init_random_frames
+    record_interval = cfg.recorder.interval
+
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+    collected_frames = 0
+
+    r0 = None
+    l0 = None
+    for data in collector:
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch
+        pbar.update(data.numel())
+        # extend replay buffer
+        replay_buffer.extend(data.view(-1))
+        if collected_frames >= init_random_frames:
+            for _ in range(optim_steps_per_batch):
+                # sample
+                sample = replay_buffer.sample(batch_size)
+                # loss
+                loss_vals = loss(sample)
+                # backprop
+                loss_val = sum(
+                    val for key, val in loss_vals.items() if key.startswith("loss")
+                )
+                loss_val.backward()
+                optim.step()
+                optim.zero_grad()
+                target_net_updater.step()
+            if r0 is None:
+                r0 = data["reward"].mean().item()
+            if l0 is None:
+                l0 = loss_val.item()
+
+            for key, value in loss_vals.items():
+                logger.log_scalar(key, value.item(), collected_frames)
+            logger.log_scalar(
+                "reward_training", data["reward"].mean().item(), collected_frames
+            )
+
+            pbar.set_description(
+                f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['reward'].mean(): 4.4f} (init={r0: 4.4f})"
+            )
+        collector.update_policy_weights_()
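+        # run an evaluation rollout (and optional video recording) every
+        # `record_interval` collected frames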
+        if (
+            collected_frames - frames_in_batch
+        ) // record_interval < collected_frames // record_interval:
+            recorder()
 
 
 if __name__ == "__main__":
diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py
new file mode 100644
index 00000000000..ce7779a7211
--- /dev/null
+++ b/examples/ddpg/utils.py
@@ -0,0 +1,457 @@
+from copy import deepcopy
+
+import torch.nn
+import torch.optim
+from tensordict.nn import TensorDictModule
+
+from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
+from torchrl.data import (
+    CompositeSpec,
+    LazyMemmapStorage,
+    MultiStep,
+    TensorDictReplayBuffer,
+)
+from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+from torchrl.envs import (
+    CatFrames,
+    CatTensors,
+    DoubleToFloat,
+    EnvCreator,
+    GrayScale,
+    NoopResetEnv,
+    ObservationNorm,
+    ParallelEnv,
+    Resize,
+    RewardScaling,
+    ToTensorImage,
+    TransformedEnv,
+)
+from torchrl.envs.libs.dm_control import DMControlEnv
+from torchrl.envs.utils import set_exploration_mode
+from torchrl.modules import (
+    AdditiveGaussianWrapper,
+    DdpgCnnActor,
+    DdpgCnnQNet,
+    DdpgMlpActor,
+    DdpgMlpQNet,
+    NoisyLinear,
+    OrnsteinUhlenbeckProcessWrapper,
+    ProbabilisticActor,
+    TanhDelta,
+    ValueOperator,
+)
+from torchrl.objectives import DDPGLoss, SoftUpdate
+from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
+from torchrl.trainers import Recorder
+from torchrl.trainers.helpers.envs import LIBS
+from torchrl.trainers.helpers.models import ACTIVATIONS
+
+
+DEFAULT_REWARD_SCALING = {
+    "Hopper-v1": 5,
+    "Walker2d-v1": 5,
+    "HalfCheetah-v1": 5,
+    "cheetah": 5,
+    "Ant-v2": 5,
+    "Humanoid-v2": 20,
+    "humanoid": 100,
+}
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def make_base_env(env_cfg, from_pixels=None):
+    env_library = LIBS[env_cfg.env_library]
+    env_name = env_cfg.env_name
+    frame_skip = env_cfg.frame_skip
+    if from_pixels is None:
+        from_pixels = env_cfg.from_pixels
+
+    env_kwargs = {
+        "env_name": env_name,
+        "frame_skip": frame_skip,
+        "from_pixels": from_pixels,  # for rendering
+        "pixels_only": False,
+    }
+    if env_library is DMControlEnv:
+        env_task = env_cfg.env_task
+        env_kwargs.update({"task_name": env_task})
+    env = env_library(**env_kwargs)
+    if env_cfg.noop > 1:
+        env = TransformedEnv(env, NoopResetEnv(env_cfg.noop))
+    return env
+
+
+def make_transformed_env(base_env, env_cfg):
+    from_pixels = env_cfg.from_pixels
+    if from_pixels:
+        return make_transformed_env_pixels(base_env, env_cfg)
+    else:
+        return make_transformed_env_states(base_env, env_cfg)
+
+
+def make_transformed_env_pixels(base_env, env_cfg):
+    if not isinstance(env_cfg.reward_scaling, float):
+        env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0)
+
+    env_library = LIBS[env_cfg.env_library]
+    env = TransformedEnv(base_env)
+
+    reward_scaling = env_cfg.reward_scaling
+
+    env.append_transform(RewardScaling(0.0, reward_scaling))
+
+    double_to_float_list = []
+    double_to_float_inv_list = []
+
+    #
+    env.append_transform(ToTensorImage())
+    env.append_transform(GrayScale())
+    env.append_transform(Resize(84, 84))
+    env.append_transform(CatFrames(N=4, dim=-3))
+
+    obs_norm = ObservationNorm(in_keys=["pixels"])
+    env.append_transform(obs_norm)
+
+    if env_library is DMControlEnv:
+        double_to_float_list += [
+            "reward",
+        ]
+        double_to_float_list += [
+            "action",
+        ]
+        double_to_float_inv_list += ["action"]  # DMControl requires double-precision
+        double_to_float_list += ["observation_vector"]
+    else:
+        double_to_float_list += ["observation_vector"]
+    env.append_transform(
+        DoubleToFloat(
+            in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
+        )
+    )
+    return env
+
+
+def make_transformed_env_states(base_env, env_cfg):
+    if not isinstance(env_cfg.reward_scaling, float):
+        env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0)
+
+    env_library = LIBS[env_cfg.env_library]
+    env = TransformedEnv(base_env)
+
+    reward_scaling = env_cfg.reward_scaling
+
+    env.append_transform(RewardScaling(0.0, reward_scaling))
+
+    double_to_float_list = []
+    double_to_float_inv_list = []
+
+    # we concatenate all the state vectors
+    # even if there is a single tensor, it'll be renamed to "observation_vector"
+    selected_keys = [
+        key for key in env.observation_spec.keys(True, True) if key != "pixels"
+    ]
+    out_key = "observation_vector"
+    env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))
+
+    obs_norm = ObservationNorm(in_keys=[out_key])
+    env.append_transform(obs_norm)
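+    # the ObservationNorm stats are not computed here: they are initialized
+    # later on, by init_stats()/get_stats() or in make_parallel_env()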
+
+    if env_library is DMControlEnv:
+        double_to_float_list += [
+            "reward",
+        ]
+        double_to_float_list += [
+            "action",
+        ]
+        double_to_float_inv_list += ["action"]  # DMControl requires double-precision
+        double_to_float_list += ["observation_vector"]
+    else:
+        double_to_float_list += ["observation_vector"]
+    env.append_transform(
+        DoubleToFloat(
+            in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
+        )
+    )
+    return env
+
+
+def make_parallel_env(env_cfg, state_dict):
+    num_envs = env_cfg.num_envs
+    env = make_transformed_env(
+        ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg
+    )
+    for t in env.transform:
+        if isinstance(t, ObservationNorm):
+            t.init_stats(3, cat_dim=1, reduce_dim=[0, 1])
+    env.load_state_dict(state_dict)
+    return env
+
+
+def get_stats(env_cfg):
+    from_pixels = env_cfg.from_pixels
+    env = make_transformed_env(make_base_env(env_cfg), env_cfg)
+    init_stats(env, env_cfg.n_samples_stats, from_pixels)
+    return env.state_dict()
+
+
+def init_stats(env, n_samples_stats, from_pixels):
+    for t in env.transform:
+        if isinstance(t, ObservationNorm):
+            if from_pixels:
+                t.init_stats(
+                    n_samples_stats,
+                    cat_dim=-3,
+                    reduce_dim=(-1, -2, -3),
+                    keep_dims=(-1, -2, -3),
+                )
+            else:
+                t.init_stats(n_samples_stats)
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(cfg, state_dict, policy):
+    env_cfg = cfg.env
+    loss_cfg = cfg.loss
+    collector_cfg = cfg.collector
+    if collector_cfg.async_collection:
+        collector_class = MultiaSyncDataCollector
+    else:
+        collector_class = MultiSyncDataCollector
+    if collector_cfg.multi_step:
+        ms = MultiStep(gamma=loss_cfg.gamma, n_steps=collector_cfg.multi_step)
+    else:
+        ms = None
+    collector = collector_class(
+        [make_parallel_env(env_cfg, state_dict=state_dict)]
+        * collector_cfg.num_collectors,
+        policy,
+        frames_per_batch=collector_cfg.frames_per_batch,
+        total_frames=collector_cfg.total_frames,
+        postproc=ms,
+        devices=collector_cfg.collector_devices,
+        init_random_frames=collector_cfg.init_random_frames,
+        max_frames_per_traj=collector_cfg.max_frames_per_traj,
+    )
+    return collector
+
+
+def make_replay_buffer(rb_cfg):
+    if rb_cfg.prb:
+        sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5)
+    else:
+        sampler = RandomSampler()
+    return TensorDictReplayBuffer(
+        storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler
+    )
+
+
+# ====================================================================
+# Model
+# -----
+#
+# We give one version of the model for learning from pixels, and one for state.
+# TorchRL comes in handy at this point, as the high-level interactions with
+# these models are unchanged, regardless of the modality.
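+#
+# In both cases make_ddpg_model() returns an (actor, value) pair: the actor
+# reads "observation_vector" (state) or "pixels" (pixel) inputs and writes a
+# "param" entry that a TanhDelta distribution maps onto the action space,
+# while the value network reads the observation together with the action and
+# writes "state_action_value".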
+#
+
+
+def make_ddpg_model(cfg):
+
+    env_cfg = cfg.env
+    model_cfg = cfg.model
+    proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg)
+    # we must initialize the observation norm transform
+    init_stats(proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels)
+
+    env_specs = proof_environment.specs
+    from_pixels = env_cfg.from_pixels
+
+    if not from_pixels:
+        actor_net, q_net = make_ddpg_modules_state(model_cfg, proof_environment)
+        in_keys = ["observation_vector"]
+        out_keys = ["param"]
+    else:
+        actor_net, q_net = make_ddpg_modules_pixels(model_cfg, proof_environment)
+        in_keys = ["pixels"]
+        out_keys = ["param", "hidden"]
+
+    actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys)
+
+    # We use a ProbabilisticActor to make sure that we map the
+    # network output to the right space using a TanhDelta
+    # distribution.
+    actor = ProbabilisticActor(
+        module=actor_module,
+        in_keys=["param"],
+        spec=CompositeSpec(action=env_specs["input_spec"]["action"]),
+        safe=True,
+        distribution_class=TanhDelta,
+        distribution_kwargs={
+            "min": env_specs["input_spec"]["action"].space.minimum,
+            "max": env_specs["input_spec"]["action"].space.maximum,
+        },
+    )
+
+    if not from_pixels:
+        in_keys = ["observation_vector", "action"]
+    else:
+        in_keys = ["pixels", "action"]
+
+    out_keys = ["state_action_value"]
+    value = ValueOperator(
+        in_keys=in_keys,
+        out_keys=out_keys,
+        module=q_net,
+    )
+
+    # init the lazy layers
+    with torch.no_grad(), set_exploration_mode("random"):
+        for t in proof_environment.transform:
+            if isinstance(t, ObservationNorm):
+                t.init_stats(2)
+        td = proof_environment.rollout(max_steps=1000)
+        print(td)
+        actor(td)
+        value(td)
+
+    return actor, value
+
+
+def make_ddpg_modules_state(model_cfg, proof_environment):
+
+    noisy = model_cfg.noisy
+
+    linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear
+
+    env_specs = proof_environment.specs
+    out_features = env_specs["input_spec"]["action"].shape[0]
+
+    actor_net_default_kwargs = {
+        "action_dim": out_features,
+        "mlp_net_kwargs": {
+            "layer_class": linear_layer_class,
+            "activation_class": ACTIVATIONS[model_cfg.activation],
+        },
+    }
+    actor_net = DdpgMlpActor(**actor_net_default_kwargs)
+
+    # Value model: DdpgMlpQNet is a specialized class that reads the state and
+    # the action and outputs a value from it. It has two sub-components that
+    # we parameterize with `mlp_net_kwargs_net1` and `mlp_net_kwargs_net2`.
+    value_net_default_kwargs1 = {
+        "layer_class": linear_layer_class,
+        "activation_class": ACTIVATIONS[model_cfg.activation],
+        "bias_last_layer": True,
+    }
+    value_net_default_kwargs2 = {
+        "num_cells": [400, 300],
+        "activation_class": ACTIVATIONS[model_cfg.activation],
+        "bias_last_layer": True,
+        "layer_class": linear_layer_class,
+    }
+    q_net = DdpgMlpQNet(
+        mlp_net_kwargs_net1=value_net_default_kwargs1,
+        mlp_net_kwargs_net2=value_net_default_kwargs2,
+    )
+    return actor_net, q_net
+
+
+def make_ddpg_modules_pixels(model_cfg, proof_environment):
+    noisy = model_cfg.noisy
+
+    linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear
+
+    env_specs = proof_environment.specs
+    out_features = env_specs["input_spec"]["action"].shape[0]
+
+    actor_net_default_kwargs = {
+        "action_dim": out_features,
+        "mlp_net_kwargs": {
+            "layer_class": linear_layer_class,
+            "activation_class": ACTIVATIONS[model_cfg.activation],
+        },
+        "conv_net_kwargs": {"activation_class": ACTIVATIONS[model_cfg.activation]},
+    }
+    actor_net = DdpgCnnActor(**actor_net_default_kwargs)
+
+    value_net_default_kwargs = {
+        "mlp_net_kwargs": {
+            "layer_class": linear_layer_class,
+            "activation_class": ACTIVATIONS[model_cfg.activation],
+        }
+    }
+    q_net = DdpgCnnQNet(**value_net_default_kwargs)
+
+    return actor_net, q_net
+
+
+def make_policy(model_cfg, actor):
+    if model_cfg.ou_exploration:
+        return OrnsteinUhlenbeckProcessWrapper(actor)
+    else:
+        return AdditiveGaussianWrapper(actor)
+
+
+# ====================================================================
+# DDPG Loss
+# ---------
+
+
+def make_loss(loss_cfg, actor_network, value_network):
+    loss = DDPGLoss(
+        actor_network,
+        value_network,
+        gamma=loss_cfg.gamma,
+        loss_function=loss_cfg.loss_function,
+    )
+    target_net_updater = SoftUpdate(loss, 1 - loss_cfg.tau)
+    target_net_updater.init_()
+    return loss, target_net_updater
+
+
+def make_optim(optim_cfg, actor_network, value_network):
+    optim = torch.optim.Adam(
+        list(actor_network.parameters()) + list(value_network.parameters()),
+        lr=optim_cfg.lr,
+        weight_decay=optim_cfg.weight_decay,
+    )
+    return optim
+
+
+# ====================================================================
+# Logging and recording
+# ---------------------
+
+
+def make_logger(logger_cfg):
+    exp_name = generate_exp_name("DDPG", logger_cfg.exp_name)
+    logger_cfg.exp_name = exp_name
+    logger = get_logger(
+        logger_cfg.backend, logger_name="ddpg", experiment_name=exp_name
+    )
+    return logger
+
+
+def make_recorder(cfg, logger, policy) -> Recorder:
+    env_cfg = deepcopy(cfg.env)
+    env = make_transformed_env(make_base_env(env_cfg, from_pixels=True), env_cfg)
+    if cfg.recorder.video:
+        env.insert_transform(
+            0, VideoRecorder(logger=logger, tag=cfg.logger.exp_name, in_keys=["pixels"])
+        )
+    return Recorder(
+        record_interval=1,
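+        # the recording frequency is handled in ddpg.py's training loop,
+        # so the Recorder itself records every time it is called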
+        record_frames=cfg.recorder.frames,
+        frame_skip=env_cfg.frame_skip,
+        policy_exploration=policy,
+        recorder=env,
+        exploration_mode="mean",
+    )
diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py
index fac07ee0afd..a2842973bfa 100644
--- a/torchrl/trainers/helpers/envs.py
+++ b/torchrl/trainers/helpers/envs.py
@@ -55,20 +55,18 @@ def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
     """
 
     # Adapt all frame counts wrt frame_skip
-    if cfg.frame_skip != 1:
-        fields = [
-            "max_frames_per_traj",
-            "total_frames",
-            "frames_per_batch",
-            "record_frames",
-            "annealing_frames",
-            "init_random_frames",
-            "init_env_steps",
-            "noops",
-        ]
-        for field in fields:
-            if hasattr(cfg, field):
-                setattr(cfg, field, getattr(cfg, field) // cfg.frame_skip)
+
+    frame_skip = cfg.env.frame_skip
+
+    if frame_skip != 1:
+        cfg.collector.max_frames_per_traj //= frame_skip
+        cfg.collector.total_frames //= frame_skip
+        cfg.collector.frames_per_batch //= frame_skip
+        cfg.collector.init_random_frames //= frame_skip
+        cfg.recorder.frames //= frame_skip
+        cfg.model.annealing_frames //= frame_skip
+        cfg.env.noop //= frame_skip
 
     return cfg
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 1e3551c3b01..aee521c90e8 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -1161,7 +1161,7 @@ def __init__(
         self.log_pbar = log_pbar
 
     @torch.inference_mode()
-    def __call__(self, batch: TensorDictBase) -> Dict:
+    def __call__(self, batch: TensorDictBase = None) -> Dict:
         out = None
         if self._count % self.record_interval == 0:
             with set_exploration_mode(self.exploration_mode):