diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml index ef9251e96c9..688921f826a 100644 --- a/.github/unittest/linux_examples/scripts/environment.yml +++ b/.github/unittest/linux_examples/scripts/environment.yml @@ -27,3 +27,5 @@ dependencies: - coverage - vmas - transformers + - gym[atari] + - gym[accept-rom-license] diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index d81e90fdd42..f435d3fc732 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -48,18 +48,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans # ==================================================================================== # # ================================ Gymnasium ========================================= # -python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ - env.num_envs=1 \ - env.device=cuda:0 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - optim.device=cuda:0 \ +python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ + env.env_name=HalfCheetah-v4 \ + collector.total_frames=40 \ + collector.frames_per_batch=20 \ loss.mini_batch_size=10 \ loss.ppo_epochs=1 \ logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False + logger.test_interval=40 +python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + loss.mini_batch_size=20 \ + loss.ppo_epochs=1 \ + logger.backend= \ + logger.test_interval=40 python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -208,18 +211,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ record_video=True \ record_frames=4 \ buffer_size=120 -python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ - env.num_envs=1 \ - env.device=cuda:0 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - optim.device=cuda:0 \ - loss.mini_batch_size=10 \ - loss.ppo_epochs=1 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ total_frames=48 \ init_random_frames=10 \ diff --git a/examples/ppo/README.md b/examples/ppo/README.md new file mode 100644 index 00000000000..7d27f746e2a --- /dev/null +++ b/examples/ppo/README.md @@ -0,0 +1,29 @@ +## Reproducing Proximal Policy Optimization (PPO) Algorithm Results + +This repository contains scripts that enable training agents using the Proximal Policy Optimization (PPO) Algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017) to implement the PPO algorithm but introduce the improvement of computing the Generalised Advantage Estimator (GAE) at every epoch. + + +## Examples Structure + +Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: + +1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. ppo_atari.py). + +2. 
**Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py). + +3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml). + + +## Running the Examples + +You can execute the PPO algorithm on Atari environments by running the following command: + +```bash +python ppo_atari.py +``` + +You can execute the PPO algorithm on MuJoCo environments by running the following command: + +```bash +python ppo_mujoco.py +``` diff --git a/examples/ppo/config.yaml b/examples/ppo/config.yaml deleted file mode 100644 index d7840906c92..00000000000 --- a/examples/ppo/config.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# task and env -defaults: - - hydra/job_logging: disabled - -env: - env_name: PongNoFrameskip-v4 - env_task: "" - env_library: gym - frame_skip: 4 - num_envs: 8 - noop: 1 - reward_scaling: 1.0 - from_pixels: True - n_samples_stats: 1000 - device: cuda:0 - -# collector -collector: - frames_per_batch: 4096 - total_frames: 40_000_000 - collector_device: cuda:0 # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: ppo_pong_gym - log_interval: 10000 - -# Optim -optim: - device: cuda:0 - lr: 2.5e-4 - weight_decay: 0.0 - lr_scheduler: True - -# loss -loss: - gamma: 0.99 - mini_batch_size: 1024 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.1 - critic_coef: 0.5 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/ppo/config_atari.yaml b/examples/ppo/config_atari.yaml new file mode 100644 index 00000000000..6957fd9bddd --- /dev/null +++ b/examples/ppo/config_atari.yaml @@ -0,0 +1,36 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 8 + +# collector +collector: + frames_per_batch: 4096 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 2.5e-4 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 1024 + ppo_epochs: 3 + gae_lambda: 0.95 + clip_epsilon: 0.1 + anneal_clip_epsilon: True + critic_coef: 1.0 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/ppo/config_example2.yaml b/examples/ppo/config_example2.yaml deleted file mode 100644 index 9d06c8a82ee..00000000000 --- a/examples/ppo/config_example2.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# task and env -env: - env_name: HalfCheetah-v4 - env_task: "" - env_library: gym - frame_skip: 1 - num_envs: 1 - noop: 1 - reward_scaling: 1.0 - from_pixels: False - n_samples_stats: 3 - device: cuda - -# collector -collector: - frames_per_batch: 2048 - total_frames: 1_000_000 - collector_device: cuda # cpu - max_frames_per_traj: -1 - -# logger -logger: - backend: wandb - exp_name: ppo_halfcheetah_gym - log_interval: 10000 - -# Optim -optim: - device: cuda - lr: 3e-4 - weight_decay: 1e-4 - lr_scheduler: False - -# loss -loss: - gamma: 0.99 - mini_batch_size: 64 - ppo_epochs: 10 - gae_lamdda: 0.95 - clip_epsilon: 0.2 - critic_coef: 0.5 - entropy_coef: 0.0 - loss_critic_type: l2 diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml new file mode 100644 index 00000000000..1272c1f4bff --- /dev/null +++ b/examples/ppo/config_mujoco.yaml @@ -0,0 +1,33 @@ +# task and env +env: + env_name: HalfCheetah-v3 + +# collector +collector: + frames_per_batch: 2048 + total_frames: 
1_000_000 + +# logger +logger: + backend: wandb + exp_name: Mujoco_Schulman17 + test_interval: 1_000_000 + num_test_episodes: 5 + +# Optim +optim: + lr: 3e-4 + weight_decay: 0.0 + anneal_lr: False + +# loss +loss: + gamma: 0.99 + mini_batch_size: 64 + ppo_epochs: 10 + gae_lambda: 0.95 + clip_epsilon: 0.2 + anneal_clip_epsilon: False + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py deleted file mode 100644 index 7f532bc0c4d..00000000000 --- a/examples/ppo/ppo.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""PPO Example. - -This is a self-contained example of a PPO training script. - -Both state and pixel-based environments are supported. - -The helper functions are coded in the utils.py associated with this script. -""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import torch - import tqdm - from tensordict import TensorDict - from torchrl.envs.utils import ExplorationType, set_exploration_type - from utils import ( - make_collector, - make_data_buffer, - make_logger, - make_loss, - make_optim, - make_ppo_models, - make_test_env, - ) - - # Correct for frame_skip - cfg.collector.total_frames = cfg.collector.total_frames // cfg.env.frame_skip - cfg.collector.frames_per_batch = ( - cfg.collector.frames_per_batch // cfg.env.frame_skip - ) - mini_batch_size = cfg.loss.mini_batch_size = ( - cfg.loss.mini_batch_size // cfg.env.frame_skip - ) - - model_device = cfg.optim.device - actor, critic, critic_head = make_ppo_models(cfg) - - collector, state_dict = make_collector(cfg, policy=actor) - data_buffer = make_data_buffer(cfg) - loss_module, adv_module = make_loss( - cfg.loss, - actor_network=actor, - value_network=critic, - value_head=critic_head, - ) - optim = make_optim(cfg.optim, loss_module) - - batch_size = cfg.collector.total_frames * cfg.env.num_envs - num_mini_batches = batch_size // mini_batch_size - total_network_updates = ( - (cfg.collector.total_frames // batch_size) - * cfg.loss.ppo_epochs - * num_mini_batches - ) - - scheduler = None - if cfg.optim.lr_scheduler: - scheduler = torch.optim.lr_scheduler.LinearLR( - optim, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1 - ) - - logger = None - if cfg.logger.backend: - logger = make_logger(cfg.logger) - test_env = make_test_env(cfg.env, state_dict) - record_interval = cfg.logger.log_interval - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - collected_frames = 0 - - # Main loop - r0 = None - l0 = None - frame_skip = cfg.env.frame_skip - ppo_epochs = cfg.loss.ppo_epochs - total_done = 0 - for data in collector: - - frames_in_batch = data.numel() - total_done += data.get(("next", "done")).sum() - collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) - - # Log end-of-episode accumulated rewards for training - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if logger is not None and len(episode_rewards) > 0: - logger.log_scalar( - "reward_training", episode_rewards.mean().item(), collected_frames - ) - - losses = TensorDict( - {}, batch_size=[ppo_epochs, -(frames_in_batch // -mini_batch_size)] - ) - for j in range(ppo_epochs): - # Compute GAE - with torch.no_grad(): - data = adv_module(data.to(model_device)).cpu() - - data_reshape = 
data.reshape(-1) - # Update the data buffer - data_buffer.extend(data_reshape) - - for i, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(model_device) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, i] = loss.detach() - - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=0.5 - ) - losses[j, i]["grad_norm"] = grad_norm - - optim.step() - if scheduler is not None: - scheduler.step() - optim.zero_grad() - - # Logging - if r0 is None: - r0 = data["next", "reward"].mean().item() - if l0 is None: - l0 = loss_sum.item() - pbar.set_description( - f"loss: {loss_sum.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['next', 'reward'].mean(): 4.4f} (init={r0: 4.4f})" - ) - if i + 1 != -(frames_in_batch // -mini_batch_size): - print( - f"Should have had {- (frames_in_batch // -mini_batch_size)} iters but had {i}." - ) - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) - if logger is not None: - for key, value in losses.items(): - logger.log_scalar(key, value.item(), collected_frames) - logger.log_scalar("total_done", total_done, collected_frames) - - collector.update_policy_weights_() - - # Test current policy - if ( - logger is not None - and (collected_frames - frames_in_batch) // record_interval - < collected_frames // record_interval - ): - - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - test_env.eval() - actor.eval() - # Generate a complete episode - td_test = test_env.rollout( - policy=actor, - max_steps=10_000_000, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - ).clone() - logger.log_scalar( - "reward_testing", - td_test["next", "reward"].sum().item(), - collected_frames, - ) - actor.train() - del td_test - - -if __name__ == "__main__": - main() diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py new file mode 100644 index 00000000000..351d0af0ae8 --- /dev/null +++ b/examples/ppo/ppo_atari.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on Atari Environments. 
+""" + +import hydra + + +@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_atari import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / 
len(episode_length), + } + ) + + # Apply episodic end of life + data["done"].copy_(data["end_of_life"]) + data["next", "done"].copy_(data["next", "end_of_life"]) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + training_start = time.time() + for j in range(cfg.loss.ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + if cfg.loss.anneal_clip_epsilon: + loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_atari_pong.png b/examples/ppo/ppo_atari_pong.png deleted file mode 100644 index 639545f29e4..00000000000 Binary files a/examples/ppo/ppo_atari_pong.png and /dev/null differ diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py new file mode 100644 index 00000000000..f081d8e69ee --- /dev/null +++ b/examples/ppo/ppo_mujoco.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on MuJoCo Environments. 
+""" +import hydra + + +@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.advantages import GAE + from torchrl.record.loggers import generate_exp_name, get_logger + from utils_mujoco import eval_model, make_env, make_ppo_models + + # Define paper hyperparameters + device = "cpu" if not torch.cuda.device_count() else "cuda" + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size + total_network_updates = ( + (cfg.collector.total_frames // cfg.collector.frames_per_batch) + * cfg.loss.ppo_epochs + * num_mini_batches + ) + + # Create models (check utils_mujoco.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device), + sampler=sampler, + batch_size=cfg.loss.mini_batch_size, + ) + + # Create loss and adv modules + adv_module = GAE( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger( + cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + training_start = time.time() + for j in range(cfg.loss.ppo_epochs): + + # 
Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + for group in critic_optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + if cfg.loss.anneal_clip_epsilon: + loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < ( + i * frames_in_batch + ) // cfg.logger.test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/ppo/ppo_mujoco_halfcheetah.png b/examples/ppo/ppo_mujoco_halfcheetah.png deleted file mode 100644 index f168a5d40f3..00000000000 Binary files a/examples/ppo/ppo_mujoco_halfcheetah.png and /dev/null differ diff --git a/examples/ppo/training_curves.md b/examples/ppo/training_curves.md deleted file mode 100644 index d9f99eadb42..00000000000 --- a/examples/ppo/training_curves.md +++ /dev/null @@ -1,13 +0,0 @@ -# PPO Example Results - -## Atari Pong Environment - -We tested the Proximal Policy Optimization (PPO) algorithm on the Atari Pong environment. The hyperparameters used for the training are specified in the config.yaml file and are the same as those used in the original PPO paper (https://arxiv.org/abs/1707.06347). - -![ppo_atari_pong.png](ppo_atari_pong.png) - -## MuJoCo HalfCheetah Environment - -Additionally, we also tested the PPO algorithm on the MuJoCo HalfCheetah environment. The hyperparameters used for the training are specified in the config_example2.yaml file and are also the same as those used in the original PPO paper. However, this implementation uses a shared policy-value architecture. 
- -![ppo_mujoco_halfcheetah.png](ppo_mujoco_halfcheetah.png) diff --git a/examples/ppo/utils.py b/examples/ppo/utils.py deleted file mode 100644 index 977d8e20b64..00000000000 --- a/examples/ppo/utils.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch.nn -import torch.optim -from tensordict.nn import NormalParamExtractor, TensorDictModule - -from torchrl.collectors import SyncDataCollector -from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement - -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - CatTensors, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ObservationNorm, - ParallelEnv, - Resize, - RewardScaling, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) -from torchrl.objectives import ClipPPOLoss -from torchrl.objectives.value.advantages import GAE -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.envs import LIBS - - -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - - -# ==================================================================== -# Environment utils -# ----------------- - - -def make_base_env(env_cfg, from_pixels=None): - env_library = LIBS[env_cfg.env_library] - env_kwargs = { - "env_name": env_cfg.env_name, - "frame_skip": env_cfg.frame_skip, - "from_pixels": env_cfg.from_pixels - if from_pixels is None - else from_pixels, # for rendering - "pixels_only": False, - "device": env_cfg.device, - } - if env_library is DMControlEnv: - env_task = env_cfg.env_task - env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) - return env - - -def make_transformed_env(base_env, env_cfg): - if env_cfg.noop > 1: - base_env = TransformedEnv(env=base_env, transform=NoopResetEnv(env_cfg.noop)) - from_pixels = env_cfg.from_pixels - if from_pixels: - return make_transformed_env_pixels(base_env, env_cfg) - else: - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_pixels(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - env.append_transform(RewardScaling(0.0, reward_scaling)) - - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - - obs_norm = ObservationNorm(in_keys=["pixels"], standard_normal=True) - env.append_transform(obs_norm) - - env.append_transform(DoubleToFloat()) - return env - - -def make_transformed_env_states(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - - 
env.append_transform(RewardScaling(0.0, reward_scaling)) - - # we concatenate all the state vectors - # even if there is a single tensor, it'll be renamed in "observation_vector" - selected_keys = [ - key for key in env.observation_spec.keys(True, True) if key != "pixels" - ] - out_key = "observation_vector" - env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter()) - # obs_norm = ObservationNorm(in_keys=[out_key]) - # env.append_transform(obs_norm) - - env.append_transform(DoubleToFloat()) - return env - - -def make_parallel_env(env_cfg, state_dict): - num_envs = env_cfg.num_envs - env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg - ) - init_stats(env, 3, env_cfg.from_pixels) - env.load_state_dict(state_dict, strict=False) - return env - - -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(env, env_cfg.n_samples_stats, env_cfg.from_pixels) - state_dict = env.state_dict() - for key in list(state_dict.keys()): - if key.endswith("loc") or key.endswith("scale"): - continue - del state_dict[key] - return state_dict - - -def init_stats(env, n_samples_stats, from_pixels): - for t in env.transform: - if isinstance(t, ObservationNorm): - if from_pixels: - t.init_stats( - n_samples_stats, - cat_dim=-4, - reduce_dim=tuple( - -i for i in range(1, len(t.parent.batch_size) + 5) - ), - keep_dims=(-1, -2, -3), - ) - else: - t.init_stats(n_samples_stats) - - -def make_test_env(env_cfg, state_dict): - env_cfg.num_envs = 1 - env = make_parallel_env(env_cfg, state_dict=state_dict) - return env - - -# ==================================================================== -# Collector and replay buffer -# --------------------------- - - -def make_collector(cfg, policy): - env_cfg = cfg.env - collector_cfg = cfg.collector - collector_class = SyncDataCollector - state_dict = get_stats(env_cfg) - collector = collector_class( - make_parallel_env(env_cfg, state_dict=state_dict), - policy, - frames_per_batch=collector_cfg.frames_per_batch, - total_frames=collector_cfg.total_frames, - device=collector_cfg.collector_device, - storing_device="cpu", - max_frames_per_traj=collector_cfg.max_frames_per_traj, - ) - return collector, state_dict - - -def make_data_buffer(cfg): - cfg_collector = cfg.collector - cfg_loss = cfg.loss - sampler = SamplerWithoutReplacement() - return TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg_collector.frames_per_batch), - sampler=sampler, - batch_size=cfg_loss.mini_batch_size, - ) - - -# ==================================================================== -# Model -# ----- -# -# We give one version of the model for learning from pixels, and one for state. -# TorchRL comes in handy at this point, as the high-level interactions with -# these models is unchanged, regardless of the modality. 
- - -def make_ppo_models(cfg): - - env_cfg = cfg.env - from_pixels = env_cfg.from_pixels - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(proof_environment, 3, env_cfg.from_pixels) - - if not from_pixels: - # we must initialize the observation norm transform - # init_stats( - # proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels - # ) - common_module, policy_module, value_module = make_ppo_modules_state( - proof_environment - ) - else: - common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ).to(cfg.optim.device) - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() - - return actor, critic, critic_head - - -def make_ppo_modules_state(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["observation_vector"].shape - - # Define distribution class and kwargs - continuous_actions = False - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - continuous_actions = True - num_outputs = proof_environment.action_spec.shape[-1] * 2 - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.low, - "max": proof_environment.action_spec.space.high, - "tanh_loc": False, - } - - # Define input keys - in_keys = ["observation_vector"] - shared_features_size = 256 - - # Define a shared Module and TensorDictModule - common_mlp = MLP( - in_features=input_shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=shared_features_size, - num_cells=[64, 64], - ) - common_module = TensorDictModule( - module=common_mlp, - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=shared_features_size, out_features=num_outputs, num_cells=[] - ) - if continuous_actions: - policy_net = torch.nn.Sequential( - policy_net, NormalParamExtractor(scale_lb=1e-2) - ) - - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["loc", "scale"] if continuous_actions else ["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["loc", "scale"] if continuous_actions else ["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP(in_features=shared_features_size, out_features=1, num_cells=[]) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -def make_ppo_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if 
isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.low, - "max": proof_environment.action_spec.space.high, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn = ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - activation_class=torch.nn.ReLU, - num_cells=[256], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - # safe=True, - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - activation_class=torch.nn.ReLU, - in_features=common_mlp_output.shape[-1], - out_features=1, - num_cells=[256], - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -# ==================================================================== -# PPO Loss -# --------- - - -def make_advantage_module(loss_cfg, value_network): - advantage_module = GAE( - gamma=loss_cfg.gamma, - lmbda=loss_cfg.gae_lamdda, - value_network=value_network, - average_gae=True, - ) - return advantage_module - - -def make_loss(loss_cfg, actor_network, value_network, value_head): - advantage_module = make_advantage_module(loss_cfg, value_network) - loss_module = ClipPPOLoss( - actor=actor_network, - critic=value_head, - clip_epsilon=loss_cfg.clip_epsilon, - loss_critic_type=loss_cfg.loss_critic_type, - entropy_coef=loss_cfg.entropy_coef, - critic_coef=loss_cfg.critic_coef, - normalize_advantage=True, - ) - return loss_module, advantage_module - - -def make_optim(optim_cfg, loss_module): - optim = torch.optim.Adam( - loss_module.parameters(), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - ) - return optim - - -# ==================================================================== -# Logging and recording -# --------------------- - - -def make_logger(logger_cfg): - exp_name = generate_exp_name("PPO", logger_cfg.exp_name) - logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="ppo", experiment_name=exp_name) - return logger diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py new file mode 100644 index 00000000000..54d920e27b1 --- /dev/null +++ 
b/examples/ppo/utils_atari.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import gymnasium as gym +import numpy as np +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + default_info_dict_reader, + DoubleToFloat, + EnvCreator, + ExplorationType, + GrayScale, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. It helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + + def step(self, action): + obs, rew, done, truncated, info = self.env.step(action) + lives = self.env.unwrapped.ale.lives() + info["end_of_life"] = False + if (lives < self.lives) or done: + info["end_of_life"] = True + self.lives = lives + return obs, rew, done, truncated, info + + def reset(self, **kwargs): + reset_data = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() + return reset_data + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = gym.make(env_name) + if not is_test: + env = EpisodicLifeEnv(env) + env = GymWrapper( + env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) + return env + + +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + 
"min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, 1, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + + del proof_environment + + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py new file mode 100644 index 00000000000..cdc681da522 --- /dev/null +++ b/examples/ppo/utils_mujoco.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + StepCounter, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu"): + env = GymEnv(env_name, device=device) + env = TransformedEnv(env) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + env.append_transform(VecNorm(in_keys=["observation"])) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = 
test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean()
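
A note on running the new entry points: both `ppo_atari.py` and `ppo_mujoco.py` are Hydra scripts, so any field in `config_atari.yaml` or `config_mujoco.yaml` can be overridden from the command line. The sketch below simply mirrors the smoke-test flags already used by `.github/unittest/linux_examples/scripts/run_test.sh` in this patch; treat it as an illustration of the override syntax rather than a recommended training configuration.

```bash
# Minimal smoke test of the Atari example (flags taken from run_test.sh in this patch):
# logging is disabled and the run is shrunk so it finishes quickly.
python ppo_atari.py \
  collector.total_frames=80 \
  collector.frames_per_batch=20 \
  loss.mini_batch_size=20 \
  loss.ppo_epochs=1 \
  logger.backend= \
  logger.test_interval=40

# The MuJoCo example accepts the analogous overrides, e.g. selecting the task:
python ppo_mujoco.py env.env_name=HalfCheetah-v4 logger.backend=
```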