Merged
Commits
181 commits
9e5f303
init
vmoens Sep 15, 2023
03b6089
fix
vmoens Sep 15, 2023
512b596
amend
vmoens Sep 17, 2023
3fa8ac0
amend
vmoens Sep 17, 2023
162aa6e
amend
vmoens Sep 17, 2023
2c5cddc
amend
vmoens Sep 17, 2023
c703a02
amend
vmoens Sep 17, 2023
2b78f49
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 17, 2023
822f42a
amend
vmoens Sep 18, 2023
db8e0c1
amend
vmoens Sep 18, 2023
9c7e1a5
amend
vmoens Sep 18, 2023
99bcc4a
amend
vmoens Sep 18, 2023
9b0069e
lint
vmoens Sep 18, 2023
ac43a7e
fix step counter
vmoens Sep 18, 2023
a822407
amend
vmoens Sep 18, 2023
6cec6e1
amend
vmoens Sep 18, 2023
0612e09
amend
vmoens Sep 18, 2023
02902db
Update torchrl/envs/gym_like.py
Sep 19, 2023
7e22c55
Update torchrl/envs/gym_like.py
Sep 19, 2023
53927bc
Update torchrl/envs/gym_like.py
Sep 19, 2023
43ff66a
amend
vmoens Sep 19, 2023
77dcd09
rollout
vmoens Sep 19, 2023
f404245
fix
vmoens Sep 19, 2023
a79b57c
fix
vmoens Sep 19, 2023
149781f
fix
vmoens Sep 19, 2023
71a05b1
fix
vmoens Sep 19, 2023
87bad08
remove prints
vmoens Sep 19, 2023
92c3814
amend
vmoens Sep 19, 2023
42d4c40
amend
vmoens Sep 19, 2023
aec627c
amend
vmoens Sep 19, 2023
b2303bd
amend
vmoens Sep 19, 2023
f6a497b
amend
vmoens Sep 19, 2023
aac630f
amend
vmoens Sep 19, 2023
606ee3a
lint and fixes
vmoens Sep 19, 2023
6ba0d38
amend
vmoens Sep 19, 2023
76b3f0c
amend
vmoens Sep 19, 2023
035c274
amend
vmoens Sep 19, 2023
8bd932f
amend
vmoens Sep 19, 2023
c789e50
amend
vmoens Sep 19, 2023
3e93f13
amend
vmoens Sep 19, 2023
cba97b1
amend
vmoens Sep 19, 2023
7ec7c78
amend
vmoens Sep 19, 2023
dd4c45e
amend
vmoens Sep 19, 2023
0ea0716
fix robohive
vmoens Sep 19, 2023
16d688e
amend
vmoens Sep 20, 2023
268dbd7
amend
vmoens Sep 20, 2023
d77d1cd
amend
vmoens Sep 20, 2023
15bd9fa
amend
vmoens Sep 20, 2023
1b656f7
amend
vmoens Sep 20, 2023
2f13c95
amend
vmoens Sep 20, 2023
aa5de06
amend
vmoens Sep 20, 2023
284262f
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 20, 2023
afcc527
amend
vmoens Sep 20, 2023
9eae41d
amend
vmoens Sep 20, 2023
9210f3b
amend
vmoens Sep 20, 2023
b31b2f0
amend
vmoens Sep 21, 2023
4e8acc0
init
vmoens Sep 21, 2023
c8579f9
init
vmoens Sep 21, 2023
d95989c
Merge branch 'fix_dreamer_tests' into threads_mp
vmoens Sep 21, 2023
e22c318
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
a5e9ce3
prints
vmoens Sep 21, 2023
b6e83d5
amend
vmoens Sep 21, 2023
acf89e6
amend
vmoens Sep 21, 2023
16fba2e
amend
vmoens Sep 21, 2023
697c523
amend
vmoens Sep 21, 2023
369492d
fix
vmoens Sep 21, 2023
c50263c
Merge branch 'main' into terminal_truncated
vmoens Sep 21, 2023
de82499
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e14dcd1
fix
vmoens Sep 21, 2023
4a8424c
amend
vmoens Sep 21, 2023
5056b9e
amend
vmoens Sep 21, 2023
ce26e13
amend
vmoens Sep 21, 2023
5a95850
amend
vmoens Sep 21, 2023
5e38d70
Update torchrl/collectors/collectors.py
Sep 21, 2023
9eb1c98
amend
vmoens Sep 21, 2023
bccbf67
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
a285780
amend
vmoens Sep 21, 2023
40a8e83
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
d8f9505
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e1eba40
Merge remote-tracking branch 'origin/main' into threads_mp
vmoens Sep 21, 2023
9f58a8d
lint
vmoens Sep 21, 2023
1c4f35f
amend
vmoens Sep 21, 2023
93bd2e6
Merge branch 'threads_mp' into terminal_truncated
vmoens Sep 21, 2023
f6e09e3
amend
vmoens Sep 21, 2023
2cd07c1
amend
vmoens Sep 22, 2023
acf6118
amend
vmoens Sep 22, 2023
bb52ce1
tests
vmoens Sep 22, 2023
0d0bc3c
Merge branch 'main' into terminal_truncated
vmoens Sep 22, 2023
3ef139b
amend
vmoens Sep 22, 2023
0d3ba02
amend
vmoens Sep 22, 2023
0b32209
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 22, 2023
0638e13
amend
vmoens Sep 22, 2023
a8ed5e5
amend
vmoens Sep 22, 2023
54db14e
amend
vmoens Sep 22, 2023
0122fac
amend
vmoens Sep 23, 2023
b18e6e7
amend
vmoens Sep 23, 2023
2f99f16
amend
vmoens Sep 23, 2023
6aa0c9e
amend
vmoens Sep 23, 2023
2b42070
amend
vmoens Sep 23, 2023
7279b12
amend
vmoens Sep 23, 2023
f5ab14d
amend
vmoens Sep 23, 2023
4a9b6b9
amend
vmoens Sep 23, 2023
696324b
amend
vmoens Sep 23, 2023
8890911
Update docs/source/reference/envs.rst
Sep 24, 2023
e24c2f3
add doc
vmoens Sep 24, 2023
c029f12
amend
vmoens Sep 24, 2023
b37129d
amend
vmoens Sep 24, 2023
9afc783
amend
vmoens Sep 24, 2023
f65622a
amend
vmoens Sep 24, 2023
989eecf
fix VIP
vmoens Sep 24, 2023
77559e0
lint
vmoens Sep 24, 2023
117e41e
osx_skips
vmoens Sep 24, 2023
19fdc33
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 24, 2023
8326674
amend
vmoens Sep 24, 2023
90da3fd
refactor tests
vmoens Sep 25, 2023
f4ddb92
Refactoring: terminated, truncated, done
vmoens Sep 25, 2023
7988118
amend
vmoens Sep 25, 2023
85d5035
let _step return partial done in batched envs
vmoens Sep 25, 2023
b4fa338
fix mocking classes
vmoens Sep 25, 2023
648792d
more fixes
vmoens Sep 25, 2023
9777332
fix step count in equivalence test
vmoens Sep 25, 2023
9063b5b
fix transforms
vmoens Sep 25, 2023
f47e49d
fix transforms
vmoens Sep 25, 2023
c238e9b
fix transformed env
vmoens Sep 25, 2023
502a59b
amend
vmoens Sep 25, 2023
73af141
amend
vmoens Sep 25, 2023
983e246
amend
vmoens Sep 26, 2023
555bca9
amend
vmoens Sep 26, 2023
85f58de
remove calls to done_key
vmoens Sep 26, 2023
6fc39d0
fix step counter
vmoens Sep 26, 2023
175b701
vec envs
vmoens Sep 26, 2023
92e7826
vec envs
vmoens Sep 26, 2023
cd6eaea
amend
vmoens Sep 26, 2023
23e139b
amend
vmoens Sep 26, 2023
677c408
amend
vmoens Sep 26, 2023
3c79081
d4rl
vmoens Sep 26, 2023
54e75b0
d4rl unsqueeze
vmoens Sep 26, 2023
244e3d3
amend
vmoens Sep 26, 2023
09abb71
amend
vmoens Sep 26, 2023
384db24
minor
vmoens Sep 26, 2023
3693cac
amend
vmoens Sep 26, 2023
ca54133
amend
vmoens Sep 26, 2023
0fdc522
amend
vmoens Sep 26, 2023
fdad78f
test_terminated_or_truncated_spec
vmoens Sep 26, 2023
f454e11
more fixes
vmoens Sep 26, 2023
88cee59
--capture no
vmoens Sep 26, 2023
57ccb63
attempt to limit collector idle time
vmoens Sep 26, 2023
e5b0d23
lint
vmoens Sep 26, 2023
dfe726f
amend
vmoens Sep 26, 2023
4f6ce90
amend
vmoens Sep 26, 2023
b16b939
amend
vmoens Sep 26, 2023
daaaddd
amend
vmoens Sep 26, 2023
cd4811f
amend
vmoens Sep 26, 2023
c0f3137
fixes
vmoens Sep 26, 2023
8f9d8fe
fix r3m, vip and vc1
vmoens Sep 27, 2023
4fdf437
fix robohive, d4rl
vmoens Sep 27, 2023
4f44579
amend
vmoens Sep 27, 2023
7787b28
amend
vmoens Sep 27, 2023
5c964a9
amend
vmoens Sep 27, 2023
03001bc
amend
vmoens Sep 27, 2023
04bbee9
lint
vmoens Sep 27, 2023
6ac830e
adapt tests
vmoens Sep 27, 2023
6fbe8c0
amend
vmoens Sep 27, 2023
28fefb6
lint
vmoens Sep 27, 2023
8256e2f
fix gym 0.19
vmoens Sep 27, 2023
1619b09
missing deps
vmoens Sep 27, 2023
23926b2
fix gym truncated
vmoens Sep 27, 2023
e3b8253
fix gym truncated (bis)
vmoens Sep 27, 2023
2a4a1b6
amend
vmoens Sep 27, 2023
7f4c38b
amend
vmoens Sep 27, 2023
2e54626
Merge branch 'main' into terminal_truncated
vmoens Sep 27, 2023
977488e
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 27, 2023
4f932e0
amend
vmoens Sep 27, 2023
5e6a0f0
amend
vmoens Sep 27, 2023
411f097
lint
vmoens Sep 27, 2023
39ed4c3
final (?)
vmoens Sep 28, 2023
21ea856
addressing review
vmoens Sep 28, 2023
f0ee4dd
more fixes
vmoens Sep 28, 2023
7906387
amend
vmoens Sep 28, 2023
72c1240
cloning dones
vmoens Sep 28, 2023
2c7ffb0
amend
vmoens Sep 29, 2023
6 changes: 3 additions & 3 deletions .github/unittest/linux/scripts/run_all.sh
@@ -178,16 +178,16 @@ python -m torch.utils.collect_env
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch

export MAX_IDLE_COUNT=100

pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 --ignore test/test_rlhf.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 --ignore test/test_rlhf.py --ignore test/test_distributed.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
fi

coverage combine
3 changes: 2 additions & 1 deletion .github/unittest/linux_examples/scripts/run_local.sh
@@ -1,6 +1,7 @@
#!/bin/bash

set -e
set -v

# Read script from line 29
filename=".github/unittest/linux_examples/scripts/run_test.sh"
@@ -12,7 +13,7 @@ script="set -e"$'\n'"$script"
script="${script//cuda:0/cpu}"

# Remove any instances of ".github/unittest/helpers/coverage_run_parallel.py"
script="${script//.circleci\/unittest\/helpers\/coverage_run_parallel.py}"
script="${script//.github\/unittest\/helpers\/coverage_run_parallel.py}"
script="${script//coverage combine}"
script="${script//coverage xml -i}"

10 changes: 6 additions & 4 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -53,16 +53,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco
collector.total_frames=40 \
collector.frames_per_batch=20 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=40
logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
loss.mini_batch_size=20 \
loss.ppo_epochs=1 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=40
logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -126,6 +126,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
optimization.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
# logger.record_video=True \
# logger.record_frames=4 \
@@ -225,6 +226,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
network.device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
replay_buffer.size=120 \
3 changes: 3 additions & 0 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -50,6 +50,8 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
# handling https://github.com/openai/gym/issues/3202
pip3 install wheel==0.38.4
pip3 install gym==$GYM_VERSION
$DIR/run_test.sh

@@ -67,6 +69,7 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install wheel==0.38.4
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install ale-py==0.7
$DIR/run_test.sh
35 changes: 20 additions & 15 deletions docs/source/reference/envs.rst
@@ -36,7 +36,7 @@ Each env will have the following attributes:
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the done-flag spec.
the done-flag spec. See the section on trajectory termination below.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
It is locked and should not be modified directly.
@@ -79,22 +79,25 @@ The following figure summarizes how a rollout is executed in torchrl.

In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method,
then populated with an action by the policy before being passed to the
:meth:`~.EnvBase.step` method which writes the observations, done flag and
:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and
reward under the ``"next"`` entry. The result of this call is stored for
delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
function.
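A minimal sketch of that loop is shown below (it is not part of this diff; it assumes a working Gym backend, and the environment name and the random stand-in policy are illustrative only):

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")  # any EnvBase subclass would do
policy = lambda td: td.set("action", env.action_spec.rand())  # stand-in for a real policy

td = env.reset()  # reset() creates the initial TensorDict
for _ in range(10):
    td = policy(td)  # the policy writes the "action" entry
    td = env.step(td)  # step() writes the observations, done flag(s) and reward under "next"
    if td["next", "done"].any():
        td = env.reset()
    else:
        td = step_mdp(td)  # gathers "next" to form the root of the following step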

.. note::

The Gym(nasium) API recently shifted to a splitting of the ``"done"`` state
into a ``terminated`` (the env is done and results should not be trusted)
and ``truncated`` (the maximum number of steps is reached) flags.
In TorchRL, ``"done"`` usually refers to ``"terminated"``. Truncation is
achieved via the :class:`~.StepCounter` transform class, and the output
key will be ``"truncated"`` if not chosen to be something else (e.g.
``StepCounter(max_steps=100, truncated_key="done")``).
TorchRL's collectors and rollout methods will be looking for one of these
keys when assessing if the env should be reset.
In general, all TorchRL environments have a ``"done"`` and a ``"terminated"``
entry in their output tensordict. If they are not present by design,
the :class:`~.EnvBase` metaclass will ensure that every done or terminated
entry is flanked with its dual.
In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory
signals and should be interpreted as "the last step of a trajectory" or
equivalently "a signal indicating the need to reset".
If the environment provides it (e.g., Gymnasium), the truncation entry is also
written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry.
If the environment carries a single value, it will be interpreted as a ``"terminated"``
signal by default.
By default, TorchRL's collectors and rollout methods look for the ``"done"``
entry to assess whether the environment should be reset.
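A hedged illustration of these conventions (a sketch only, not taken from this diff; it assumes a Gym backend, and the environment name and step limit are arbitrary) uses the StepCounter transform to raise ``"truncated"`` while ``"done"`` reflects the union of the end-of-trajectory signals:

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter(max_steps=20))
rollout = env.rollout(100)  # stops as soon as "done" is raised
# on the last step either "terminated" or "truncated" is set; "done" is their union
print(rollout["next", "terminated"][-1], rollout["next", "truncated"][-1], rollout["next", "done"][-1])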

.. note::

@@ -172,12 +175,13 @@ It is also possible to reset some but not all of the environments:
:caption: Parallel environment reset

>>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4])
>>> env.reset(tensordict)
>>> env.reset(tensordict) # eliminates the "_reset" entry
TensorDict(
fields={
terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
_reset: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
batch_size=torch.Size([4]),
device=None,
is_shared=True)
@@ -238,7 +242,7 @@ Some of the main differences between these paradigms include:

- **observation** can be per-agent and also have some shared components
- **reward** can be per-agent or shared
- **done** can be per-agent or shared
- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared.

TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier.
In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict.
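As a hypothetical sketch (not part of this diff) of how shared and per-agent entries can coexist, the step output of a multi-agent environment with 4 workers and 3 agents might look like:

import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "done": torch.zeros(4, 1, dtype=torch.bool),  # shared done flag
        "agents": TensorDict(
            {
                "reward": torch.zeros(4, 3, 1),  # per-agent reward
                "done": torch.zeros(4, 3, 1, dtype=torch.bool),  # per-agent done
            },
            batch_size=[4, 3],
        ),
    },
    batch_size=[4],
)
print(td["agents", "reward"].shape)  # torch.Size([4, 3, 1])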
@@ -586,6 +590,7 @@ Helpers
exploration_type
check_env_specs
make_composite_from_td
terminated_or_truncated

Domain-specific
---------------
4 changes: 0 additions & 4 deletions examples/a2c/a2c_atari.py
@@ -124,10 +124,6 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])

losses = TensorDict({}, batch_size=[num_mini_batches])
training_start = time.time()

2 changes: 1 addition & 1 deletion examples/decision_transformer/utils.py
@@ -232,7 +232,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
batch_size=rb_cfg.batch_size,
sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False),
transform=transforms,
use_timeout_as_done=True,
use_truncated_as_done=True,
)
full_data = data._get_dataset_from_env(rb_cfg.dataset, {})
loc = full_data["observation"].mean(axis=0).float()
44 changes: 27 additions & 17 deletions examples/ppo/ppo_atari.py
@@ -78,6 +78,9 @@ def main(cfg: "DictConfig"): # noqa: F821
normalize_advantage=True,
)

# use end-of-life as done key
loss_module.set_keys(done="eol")

# Create optimizer
optim = torch.optim.Adam(
loss_module.parameters(),
@@ -109,6 +112,18 @@ def main(cfg: "DictConfig"): # noqa: F821
)

sampling_start = time.time()

# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.optim.anneal_lr
cfg_optim_lr = cfg.optim.lr
cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

log_info = {}
@@ -120,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
episode_length = data["next", "step_count"][data["next", "stop"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
@@ -129,13 +144,8 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
training_start = time.time()
for j in range(cfg.loss.ppo_epochs):
for j in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
@@ -149,12 +159,12 @@ def main(cfg: "DictConfig"): # noqa: F821

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
if cfg.loss.anneal_clip_epsilon:
loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1

# Get a data batch
@@ -172,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
)

# Update the networks
Expand All @@ -181,15 +191,15 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get training losses and times
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses_mean.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
}
)

Expand All @@ -201,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
38 changes: 24 additions & 14 deletions examples/ppo/ppo_mujoco.py
@@ -100,6 +100,17 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

sampling_start = time.time()

# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.optim.anneal_lr
cfg_optim_lr = cfg.optim.lr
cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

log_info = {}
@@ -120,9 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
training_start = time.time()
for j in range(cfg.loss.ppo_epochs):
for j in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
@@ -136,14 +146,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in actor_optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
group["lr"] = cfg_optim_lr * alpha
for group in critic_optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
if cfg.loss.anneal_clip_epsilon:
loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1

# Forward pass PPO loss
@@ -166,27 +176,27 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get training losses and times
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses_mean.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
}
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg.logger.test_interval:
) // cfg_logger_test_interval:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(