2 changes: 2 additions & 0 deletions .github/unittest/linux_examples/scripts/environment.yml
@@ -27,3 +27,5 @@ dependencies:
- coverage
- vmas
- transformers
- gym[atari]
- gym[accept-rom-license]
33 changes: 12 additions & 21 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -48,18 +48,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
env.device=cuda:0 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
optim.device=cuda:0 \
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
collector.frames_per_batch=20 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False
logger.test_interval=40
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
loss.mini_batch_size=20 \
loss.ppo_epochs=1 \
logger.backend= \
logger.test_interval=40
python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -208,18 +208,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
record_video=True \
record_frames=4 \
buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
env.device=cuda:0 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
optim.device=cuda:0 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
total_frames=48 \
init_random_frames=10 \
29 changes: 29 additions & 0 deletions examples/ppo/README.md
@@ -0,0 +1,29 @@
## Reproducing Proximal Policy Optimization (PPO) Algorithm Results

This repository contains scripts for training agents with the Proximal Policy Optimization (PPO) algorithm on MuJoCo and Atari environments. The implementation follows the original paper, [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017), with one improvement: the Generalised Advantage Estimator (GAE) is recomputed at every epoch.
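
As a rough illustration of that improvement, the training loop recomputes the advantages on the full batch at the start of every PPO epoch, before sampling mini-batches. Below is a minimal sketch using TorchRL's `GAE` module; `critic`, `loss_module`, `data`, `ppo_epochs` and `mini_batch_size` are placeholders rather than the exact objects built in the scripts.

```python
import torch
from torchrl.objectives.value import GAE

# Advantage module; gamma/lmbda mirror the values in the config files.
adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)

for epoch in range(ppo_epochs):
    # Recompute GAE on the whole collected batch at every epoch, so the
    # advantages reflect the critic's latest parameters.
    with torch.no_grad():
        data = adv_module(data)
    flat = data.reshape(-1)
    for start in range(0, flat.numel(), mini_batch_size):
        mini_batch = flat[start : start + mini_batch_size]
        loss_vals = loss_module(mini_batch)  # e.g. a clipped PPO loss
        # ... backward pass, gradient clipping and optimizer step ...
```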


## Examples Structure

Please note that, for the sake of simplicity, each example is independent of the others. Each example contains the following files:

1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. ppo_atari.py).

2. **Utils File:** A utility file provides various helper functions, mainly for creating the environment and the models (e.g. utils_atari.py).

3. **Configuration File:** This file contains the default hyperparameters specified in the original paper; users can modify them to customize their experiments (e.g. config_atari.yaml). A sketch of how such a file might be consumed follows this list.
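
As a rough sketch of how such a configuration file is typically consumed (the dotted command-line overrides used in `run_test.sh` suggest Hydra; the exact decorator arguments below are an assumption, not taken from the scripts):

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path=".", config_name="config_atari")
def main(cfg: DictConfig):
    # Values from config_atari.yaml arrive as nested attributes and can be
    # overridden on the command line, e.g. `loss.ppo_epochs=1 logger.backend=`.
    print(cfg.env.env_name, cfg.optim.lr, cfg.loss.ppo_epochs)


if __name__ == "__main__":
    main()
```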


## Running the Examples

You can execute the PPO algorithm on Atari environments by running the following command:

```bash
python ppo_atari.py
```

You can execute the PPO algorithm on MuJoCo environments by running the following command:

```bash
python ppo_mujoco.py
```
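
Both scripts read their defaults from the corresponding configuration file, and individual hyperparameters can be overridden from the command line with dotted keys, as done in the CI script `run_test.sh`. For example (the particular combination below is only illustrative):

```bash
python ppo_mujoco.py env.env_name=HalfCheetah-v4 collector.total_frames=1000000 loss.ppo_epochs=10 logger.backend=
```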
46 changes: 0 additions & 46 deletions examples/ppo/config.yaml

This file was deleted.

36 changes: 36 additions & 0 deletions examples/ppo/config_atari.yaml
@@ -0,0 +1,36 @@
# Environment
env:
  env_name: PongNoFrameskip-v4
  num_envs: 8

# collector
collector:
  frames_per_batch: 4096
  total_frames: 40_000_000

# logger
logger:
  backend: wandb
  exp_name: Atari_Schulman17
  test_interval: 40_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 2.5e-4
  eps: 1.0e-6
  weight_decay: 0.0
  max_grad_norm: 0.5
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  mini_batch_size: 1024
  ppo_epochs: 3
  gae_lambda: 0.95
  clip_epsilon: 0.1
  anneal_clip_epsilon: True
  critic_coef: 1.0
  entropy_coef: 0.01
  loss_critic_type: l2
43 changes: 0 additions & 43 deletions examples/ppo/config_example2.yaml

This file was deleted.

33 changes: 33 additions & 0 deletions examples/ppo/config_mujoco.yaml
@@ -0,0 +1,33 @@
# task and env
env:
  env_name: HalfCheetah-v3

# collector
collector:
  frames_per_batch: 2048
  total_frames: 1_000_000

# logger
logger:
  backend: wandb
  exp_name: Mujoco_Schulman17
  test_interval: 1_000_000
  num_test_episodes: 5

# Optim
optim:
  lr: 3e-4
  weight_decay: 0.0
  anneal_lr: False

# loss
loss:
  gamma: 0.99
  mini_batch_size: 64
  ppo_epochs: 10
  gae_lambda: 0.95
  clip_epsilon: 0.2
  anneal_clip_epsilon: False
  critic_coef: 0.25
  entropy_coef: 0.0
  loss_critic_type: l2