Merged
35 commits
6339a07 update executable (BY571, Sep 6, 2023)
9e890b3 fix objective (BY571, Sep 7, 2023)
117c477 fix objective (BY571, Sep 7, 2023)
d2b3ad4 Update initial frames and general structure (BY571, Sep 12, 2023)
9c6c358 fixes (BY571, Sep 12, 2023)
1adbff5 Merge branch 'main' into td3_benchmark (BY571, Sep 12, 2023)
2422ef8 naming fix (BY571, Sep 12, 2023)
0e67de2 single step td3 (BY571, Sep 13, 2023)
1fc0847 small fixes (BY571, Sep 14, 2023)
7a02b83 fix (BY571, Sep 14, 2023)
243d712 add update counter (BY571, Sep 14, 2023)
af31bd9 naming fixes (BY571, Sep 14, 2023)
1122808 update logging and small fixes (BY571, Sep 15, 2023)
b4df32b no eps (BY571, Sep 18, 2023)
13f367a update tests (BY571, Sep 19, 2023)
72ddf7e update objective (BY571, Sep 20, 2023)
c830891 set gym backend (BY571, Sep 20, 2023)
1a2f08e Merge branch 'main' into td3_benchmark (vmoens, Sep 21, 2023)
4cdbb3b update tests (BY571, Sep 21, 2023)
76dcdeb update fix max episode steps (BY571, Sep 22, 2023)
68d4c26 Merge branch 'main' into td3_benchmark (BY571, Sep 26, 2023)
ec8b089 fix (BY571, Sep 27, 2023)
bcc3bc6 fix (BY571, Sep 27, 2023)
42748e0 amend (vmoens, Sep 28, 2023)
0052cd9 Merge remote-tracking branch 'BY571/td3_benchmark' into td3_benchmark (vmoens, Sep 28, 2023)
e2c28c8 amend (vmoens, Sep 28, 2023)
bb496ef update scratch_dir, frame skip, config (BY571, Sep 28, 2023)
9b4704b Merge branch 'main' into td3_benchmark (BY571, Oct 2, 2023)
e622bf7 merge main (BY571, Oct 2, 2023)
57bc54a merge main (BY571, Oct 2, 2023)
29977df step counter (BY571, Oct 2, 2023)
854e2a2 merge main (BY571, Oct 3, 2023)
619f2ea small fixes (BY571, Oct 3, 2023)
8d36787 solve logger issue (vmoens, Oct 3, 2023)
a24ab8d reset notensordict test (vmoens, Oct 3, 2023)
28 changes: 15 additions & 13 deletions examples/td3/config.yaml
@@ -1,47 +1,49 @@
# Environment
# task and env
env:
name: HalfCheetah-v3
task: ""
exp_name: "HalfCheetah-TD3"
exp_name: "HalfCheetah-TD3-ICLR"
library: gym
frame_skip: 1
seed: 42

# Collection
# collector
collector:
total_frames: 1000000
init_random_frames: 10000
total_frames: 3000000
init_random_frames: 25_000
init_env_steps: 1000
frames_per_batch: 1000
frames_per_batch: 1
max_frames_per_traj: 1000
async_collection: 1
collector_device: cpu
env_per_collector: 1
num_workers: 1

# Replay Buffer
# replay buffer
replay_buffer:
prb: 0 # use prioritized experience replay
size: 1000000

# Optimization
optimization:
# optim
optim:
utd_ratio: 1.0
gamma: 0.99
loss_function: l2
lr: 3e-4
weight_decay: 2e-4
lr: 3.0e-4
weight_decay: 0.0
batch_size: 256
target_update_polyak: 0.995
policy_update_delay: 2
policy_noise: 0.2
noise_clip: 0.5

# Network
# network
network:
hidden_sizes: [256, 256]
activation: relu
device: "cuda:0"

# Logging
# logging
logger:
backend: wandb
mode: online
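For readers tracing the `optimization` -> `optim` rename through the rest of the diff, the following is a minimal sketch of how the renamed keys resolve once the YAML above is loaded. It reads the file with OmegaConf directly instead of going through the Hydra entry point the script actually uses, and the printed values are the new ones introduced by this PR; it is an illustration, not part of the changeset.

    from omegaconf import OmegaConf

    # Load the example config directly (sketch only; td3.py resolves it via @hydra.main).
    cfg = OmegaConf.load("examples/td3/config.yaml")

    # The optimization section is now named `optim`, so downstream code reads:
    print(cfg.optim.lr)                    # 3.0e-4
    print(cfg.optim.batch_size)            # 256
    print(cfg.optim.target_update_polyak)  # 0.995
    print(cfg.collector.frames_per_batch)  # 1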
120 changes: 69 additions & 51 deletions examples/td3/td3.py
Expand Up @@ -11,8 +11,9 @@
The helper functions are coded in the utils.py associated with this script.
"""

import hydra
import time

import hydra
import numpy as np
import torch
import torch.cuda
Expand All @@ -35,6 +36,7 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("TD3", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
Expand All @@ -45,140 +47,156 @@ def main(cfg: "DictConfig"): # noqa: F821
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create Environments
# Create environments
train_env, eval_env = make_environment(cfg)

# Create Agent
# Create agent
model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Make Off-Policy Collector
# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)

# Make Replay Buffer
# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optimization.batch_size,
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
device=device,
)

# Make Optimizers
# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

rewards = []
rewards_eval = []

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
r0 = None
q_loss = None

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optimization.utd_ratio
* cfg.optim.utd_ratio
)
delayed_updates = cfg.optimization.policy_update_delay
delayed_updates = cfg.optim.policy_update_delay
prb = cfg.replay_buffer.prb
env_per_collector = cfg.collector.env_per_collector
eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip
eval_iter = cfg.logger.eval_iter
frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
update_counter = 0

for i, tensordict in enumerate(collector):
sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
exploration_policy.step(tensordict.numel())
# update weights of the inference policy

# Update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for j in range(num_updates):
# sample from replay buffer
for _ in range(num_updates):
update_counter += 1
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]

# Update critic
optimizer_critic.zero_grad()
update_actor = j % delayed_updates == 0
update_actor = update_counter % delayed_updates == 0
q_loss.backward(retain_graph=update_actor)
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
if update_actor:
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

actor_losses.append(actor_loss.item())

# update qnet_target params
# Update target params
target_net_updater.step()

# update priority
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append(
(i, tensordict["next", "reward"].sum().item() / env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
}
training_time = time.time() - training_start
episode_rewards = tensordict["next", "episode_reward"][
tensordict["next", "done"]
]

# Logging
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][
tensordict["next", "done"]
]
logger.log_scalar(
"train/reward", episode_rewards.mean().item(), collected_frames
)
logger.log_scalar(
"train/episode_length",
episode_length.sum().item() / len(episode_length),
collected_frames,
)
if logger is not None:
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

if collected_frames >= init_random_frames:
logger.log_scalar("train/q_loss", np.mean(q_losses), step=collected_frames)
if update_actor:
logger.log_scalar(
"train/a_loss", np.mean(actor_losses), step=collected_frames
)
logger.log_scalar("train/sampling_time", sampling_time, collected_frames)
logger.log_scalar("train/training_time", training_time, collected_frames)

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
if logger is not None:
logger.log_scalar(
"evaluation_reward", rewards_eval[-1][1], step=collected_frames
)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)
logger.log_scalar("eval/reward", eval_reward, step=collected_frames)
logger.log_scalar("eval/time", eval_time, step=collected_frames)

sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
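One detail in the diff above worth spelling out: with `frames_per_batch: 1` and `utd_ratio: 1.0`, the inner loop runs a single gradient step per collected batch, so the old gate `j % delayed_updates == 0` would evaluate at `j == 0` every time and the actor would be updated on every step. A counter that persists across collector iterations restores the intended schedule of one actor/target update per `policy_update_delay` critic updates. Below is a minimal, self-contained sketch of that pattern (placeholder names and loop bounds, not the PR's code):

    policy_update_delay = 2   # as in cfg.optim.policy_update_delay
    num_updates = 1           # e.g. frames_per_batch=1, utd_ratio=1.0

    update_counter = 0
    for _batch in range(6):              # stand-in for iterating the collector
        for _ in range(num_updates):     # gradient steps for this collected batch
            update_counter += 1
            update_actor = update_counter % policy_update_delay == 0
            # The critic is updated on every step; the actor and target nets
            # only when update_actor is True.
            print(update_counter, update_actor)  # 1 False, 2 True, 3 False, 4 True, ...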
22 changes: 13 additions & 9 deletions examples/td3/utils.py
Expand Up @@ -10,6 +10,7 @@
EnvCreator,
InitTracker,
ParallelEnv,
RewardSum,
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
Expand Down Expand Up @@ -43,7 +44,8 @@ def apply_env_transforms(env, reward_scaling=1.0):
Compose(
InitTracker(),
RewardScaling(loc=0.0, scale=reward_scaling),
DoubleToFloat(),
DoubleToFloat("observation"),
RewardSum(),
),
)
return transformed_env
Expand Down Expand Up @@ -79,6 +81,7 @@ def make_collector(cfg, train_env, actor_model_explore):
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
Expand Down Expand Up @@ -222,17 +225,18 @@ def make_loss_module(cfg, model):
actor_network=model[0],
qvalue_network=model[1],
num_qvalue_nets=2,
loss_function=cfg.optimization.loss_function,
loss_function=cfg.optim.loss_function,
delay_actor=True,
delay_qvalue=True,
gamma=cfg.optim.gamma,
action_spec=model[0][1].spec,
policy_noise=cfg.optim.policy_noise,
noise_clip=cfg.optim.noise_clip,
)
loss_module.make_value_estimator(gamma=cfg.optimization.gamma)
loss_module.make_value_estimator(gamma=cfg.optim.gamma)

# Define Target Network Updater
target_net_updater = SoftUpdate(
loss_module, eps=cfg.optimization.target_update_polyak
)
target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
return loss_module, target_net_updater
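The `policy_noise` and `noise_clip` values passed to the loss module above implement TD3's target-policy smoothing. As a reference, here is a generic sketch of that computation (standard TD3, not TorchRL's internal code; the function name and arguments are made up for illustration, and the action bounds are assumed to be scalars):

    import torch

    def smoothed_target_action(target_policy_action, action_low, action_high,
                               policy_noise=0.2, noise_clip=0.5):
        # Add clipped Gaussian noise to the target-policy action,
        # then clip the result back into the valid action range.
        noise = (torch.randn_like(target_policy_action) * policy_noise).clamp(
            -noise_clip, noise_clip
        )
        return (target_policy_action + noise).clamp(action_low, action_high)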


Expand All @@ -241,11 +245,11 @@ def make_optimizer(cfg, loss_module):
actor_params = list(loss_module.actor_network_params.flatten_keys().values())

optimizer_actor = optim.Adam(
actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay
actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay
)
optimizer_critic = optim.Adam(
critic_params,
lr=cfg.optimization.lr,
weight_decay=cfg.optimization.weight_decay,
lr=cfg.optim.lr,
weight_decay=cfg.optim.weight_decay,
)
return optimizer_actor, optimizer_critic
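Finally, the `eps=cfg.optim.target_update_polyak` argument to SoftUpdate above corresponds to Polyak averaging of the target networks. The sketch below shows the update rule assumed here (keep `eps` of the old target weights and mix in `1 - eps` of the online weights); it is a generic illustration, not TorchRL's implementation:

    import torch

    def polyak_update(target_params, online_params, eps=0.995):
        # target <- eps * target + (1 - eps) * online, parameter-wise, no gradients
        with torch.no_grad():
            for t, p in zip(target_params, online_params):
                t.mul_(eps).add_(p, alpha=1.0 - eps)

    # Hypothetical usage with two identically shaped modules:
    online = torch.nn.Linear(4, 4)
    target = torch.nn.Linear(4, 4)
    polyak_update(target.parameters(), online.parameters(), eps=0.995)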