From 6339a07c016a9e611c6f592c8aff49de9a848d19 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 6 Sep 2023 14:26:00 +0200 Subject: [PATCH 01/28] update executable --- examples/td3/config.yaml | 10 +++++----- examples/td3/utils.py | 2 +- torchrl/objectives/td3.py | 6 ++++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 35a2d9f8b2f..85cd893aef3 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -2,14 +2,14 @@ env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-TD3" + exp_name: "HalfCheetah-TD3-ICLR" library: gym frame_skip: 1 seed: 42 # Collection collector: - total_frames: 1000000 + total_frames: 3000000 init_random_frames: 10000 init_env_steps: 1000 frames_per_batch: 1000 @@ -29,9 +29,9 @@ optimization: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 3e-4 - weight_decay: 2e-4 - batch_size: 256 + lr: 1e-3 + weight_decay: 0.0 + batch_size: 100 target_update_polyak: 0.995 policy_update_delay: 2 diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 9a8c5809f75..f634597425f 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -43,7 +43,7 @@ def apply_env_transforms(env, reward_scaling=1.0): Compose( InitTracker(), RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(), + DoubleToFloat("observation"), ), ) return transformed_env diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 62f0e793f29..d37373d730b 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -336,13 +336,13 @@ def _cached_detach_qvalue_network_params(self): def _cached_stack_actor_params(self): return torch.stack( [self.actor_network_params, self.target_actor_network_params], 0 - ) + ).to_tensordict() @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys tensordict_save = tensordict - tensordict = tensordict.clone(False) + tensordict = tensordict.clone(False).to_tensordict() act = tensordict.get(self.tensor_keys.action) action_shape = act.shape action_device = act.device @@ -365,12 +365,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_actor, self._cached_stack_actor_params, ) + # add noise to target policy actor_output_td1 = actor_output_td[1] next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp( self.min_action, self.max_action ) actor_output_td1.set(self.tensor_keys.action, next_action) + actor_output_td = torch.stack([actor_output_td[0], actor_output_td1], 0) tensordict_actor.set( self.tensor_keys.action, actor_output_td.get(self.tensor_keys.action), From 9e890b3154861d20fa6a060ab8f491f16054fe9c Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 7 Sep 2023 14:42:43 +0200 Subject: [PATCH 02/28] fix objective --- examples/td3/config.yaml | 6 ++++-- examples/td3/utils.py | 2 ++ torchrl/objectives/td3.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 85cd893aef3..d46933e539d 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -29,11 +29,13 @@ optimization: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 1e-3 + lr: 3.0e-4 weight_decay: 0.0 - batch_size: 100 + batch_size: 256 target_update_polyak: 0.995 policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 # Network network: diff --git a/examples/td3/utils.py b/examples/td3/utils.py index f634597425f..1479ed2f670 100644 --- a/examples/td3/utils.py +++ 
b/examples/td3/utils.py @@ -226,6 +226,8 @@ def make_loss_module(cfg, model): delay_actor=True, delay_qvalue=True, action_spec=model[0][1].spec, + policy_noise=cfg.optimization.policy_noise, + policy_noise_clip=cfg.optimization.policy_noise_clip, ) loss_module.make_value_estimator(gamma=cfg.optimization.gamma) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index d37373d730b..e03cd5a1ade 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -336,7 +336,7 @@ def _cached_detach_qvalue_network_params(self): def _cached_stack_actor_params(self): return torch.stack( [self.actor_network_params, self.target_actor_network_params], 0 - ).to_tensordict() + ) @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -363,14 +363,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # DO NOT call contiguous bc we'll update the tds later actor_output_td = self._vmap_actor_network00( tensordict_actor, - self._cached_stack_actor_params, + self._cached_stack_actor_params.to_tensordict(), ) # add noise to target policy - actor_output_td1 = actor_output_td[1] - next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp( - self.min_action, self.max_action - ) + with torch.no_grad(): + actor_output_td1 = actor_output_td[1] + next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp( + self.min_action, self.max_action + ) actor_output_td1.set(self.tensor_keys.action, next_action) actor_output_td = torch.stack([actor_output_td[0], actor_output_td1], 0) tensordict_actor.set( @@ -405,7 +406,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # cat params qvalue_params = torch.cat( [ - self._cached_detach_qvalue_network_params, + self.qvalue_network_params, # self._cached_detach_qvalue_network_params, self.target_qvalue_network_params, self.qvalue_network_params, ], @@ -428,7 +429,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: dim=0, ) - loss_actor = -(state_action_value_actor.min(0)[0]).mean() + loss_actor = -(state_action_value_actor[0]).mean() # .min(0) next_state_value = next_state_action_value_qvalue.min(0)[0] tensordict.set( @@ -446,7 +447,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) .mean(-1) .sum() - * 0.5 ) tensordict_save.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) From 117c4779792e82587faeb3b3923aed0b856faba1 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 7 Sep 2023 14:45:40 +0200 Subject: [PATCH 03/28] fix objective --- torchrl/objectives/td3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index e03cd5a1ade..992d0573373 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -349,7 +349,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # computing early for reprod noise = torch.normal( mean=torch.zeros(action_shape), - std=torch.full(action_shape, self.policy_noise), + std=torch.full(action_shape, self.max_action * self.policy_noise), ).to(action_device) noise = noise.clamp(-self.noise_clip, self.noise_clip) From d2b3ad49c78423ad25a3ca526d0c93e7583c598c Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 12 Sep 2023 12:06:21 +0200 Subject: [PATCH 04/28] Update initial frames and general structure --- examples/td3/config.yaml | 20 ++++----- examples/td3/td3.py | 95 +++++++++++++++++++++------------------ examples/td3/utils.py | 21 ++++----- torchrl/objectives/td3.py | 12 ++--- 4 files changed, 78 
insertions(+), 70 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index d46933e539d..50ccd8d3edf 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,16 +1,16 @@ -# Environment +# task and env env: - name: HalfCheetah-v3 + name: HalfCheetah-v2 task: "" - exp_name: "HalfCheetah-TD3-ICLR" + exp_name: "HalfCheetah-TD2-ICLR" library: gym frame_skip: 1 seed: 42 -# Collection +# collector collector: total_frames: 3000000 - init_random_frames: 10000 + init_random_frames: 25_000 init_env_steps: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 @@ -19,13 +19,13 @@ collector: env_per_collector: 1 num_workers: 1 -# Replay Buffer +# replay buffer replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 -# Optimization -optimization: +# optim +optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 @@ -37,13 +37,13 @@ optimization: policy_noise: 0.2 noise_clip: 0.5 -# Network +# network network: hidden_sizes: [256, 256] activation: relu device: "cuda:0" -# Logging +# logging logger: backend: wandb mode: online diff --git a/examples/td3/td3.py b/examples/td3/td3.py index f4d8707f404..880b803ae66 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -11,8 +11,9 @@ The helper functions are coded in the utils.py associated with this script. """ -import hydra +import time +import hydra import numpy as np import torch import torch.cuda @@ -57,49 +58,45 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create Off-Policy Collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create Replay Buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, device=device, ) - # Make Optimizers + # Create Optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) - rewards = [] - rewards_eval = [] - # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) - delayed_updates = cfg.optimization.policy_update_delay + delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip eval_iter = cfg.logger.eval_iter frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip - for i, tensordict in enumerate(collector): + sampling_start = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start exploration_policy.step(tensordict.numel()) + # update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) tensordict = tensordict.reshape(-1) @@ -108,12 +105,13 @@ def main(cfg: "DictConfig"): # noqa: F821 collected_frames += current_frames # optimization steps + training_start = time.time() if collected_frames >= init_random_frames: ( actor_losses, q_losses, ) = ([], []) - for j in range(num_updates): + for i in range(num_updates): # sample from replay buffer sampled_tensordict = 
replay_buffer.sample().clone() @@ -123,7 +121,7 @@ def main(cfg: "DictConfig"): # noqa: F821 q_loss = loss_td["loss_qvalue"] optimizer_critic.zero_grad() - update_actor = j % delayed_updates == 0 + update_actor = i % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) @@ -132,6 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor.zero_grad() actor_loss.backward() optimizer_actor.step() + actor_losses.append(actor_loss.item()) # update qnet_target params @@ -141,44 +140,52 @@ def main(cfg: "DictConfig"): # noqa: F821 if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) - ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - } + training_time = time.time() - training_start + episode_rewards = tensordict["next", "episode_reward"][ + tensordict["next", "done"] + ] + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + logger.log_scalar( + "train/reward", episode_rewards.mean().item(), collected_frames + ) + logger.log_scalar( + "train/episode_length", + episode_length.sum().item() / len(episode_length), + collected_frames, ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) + + if collected_frames >= init_random_frames: + logger.log_scalar("train/q_loss", np.mean(q_losses), step=collected_frames) + logger.log_scalar( + "train/a_loss", np.mean(actor_losses), step=collected_frames + ) + logger.log_scalar("train/sampling_time", sampling_time, collected_frames) + logger.log_scalar("train/training_time", training_time, collected_frames) + + # evaluation if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + logger.log_scalar("eval/reward", eval_reward, step=collected_frames) + logger.log_scalar("eval/time", eval_time, step=collected_frames) + + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 1479ed2f670..4419900cc8e 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -10,6 +10,7 @@ EnvCreator, InitTracker, ParallelEnv, + RewardSum, TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv @@ -44,6 +45,7 @@ def apply_env_transforms(env, reward_scaling=1.0): InitTracker(), RewardScaling(loc=0.0, scale=reward_scaling), DoubleToFloat("observation"), + RewardSum(), ), ) return transformed_env @@ -79,6 
+81,7 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, @@ -222,19 +225,17 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, action_spec=model[0][1].spec, - policy_noise=cfg.optimization.policy_noise, - policy_noise_clip=cfg.optimization.policy_noise_clip, + policy_noise=cfg.optim.policy_noise, + noise_clip=cfg.optim.noise_clip, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater @@ -243,11 +244,11 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay + actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, ) return optimizer_actor, optimizer_critic diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 992d0573373..d8ca7c3857a 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -347,11 +347,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action_shape = act.shape action_device = act.device # computing early for reprod - noise = torch.normal( - mean=torch.zeros(action_shape), - std=torch.full(action_shape, self.max_action * self.policy_noise), - ).to(action_device) - noise = noise.clamp(-self.noise_clip, self.noise_clip) + noise = ( + (torch.randn(action_shape) * self.policy_noise) + .clamp(-self.noise_clip, self.noise_clip) + .to(action_device) + ) tensordict_actor_grad = tensordict.select( *obs_keys @@ -406,7 +406,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # cat params qvalue_params = torch.cat( [ - self.qvalue_network_params, # self._cached_detach_qvalue_network_params, + self._cached_detach_qvalue_network_params, # self.qvalue_network_params, # self.target_qvalue_network_params, self.qvalue_network_params, ], From 9c6c358edb8939dd8597e0b2eb6c06349af5c925 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 12 Sep 2023 12:10:07 +0200 Subject: [PATCH 05/28] fixes --- examples/td3/td3.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 880b803ae66..9a5c803b3d3 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -36,6 +36,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create Logger exp_name = generate_exp_name("TD3", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -46,6 +47,7 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) + # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) 
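Note on the noise change in PATCH 04 above: the torch.normal-based target noise (scaled by max_action * policy_noise) is replaced by torch.randn scaled by policy_noise and clamped to +/- noise_clip, dropping the max_action scaling. A minimal standalone sketch of that target-policy-smoothing step; target_actor, next_obs and the action bounds are assumed placeholders here, not TorchRL API:

    import torch

    def smoothed_target_action(target_actor, next_obs, policy_noise=0.2,
                               noise_clip=0.5, min_action=-1.0, max_action=1.0):
        # Add clamped Gaussian noise to the target policy's action
        # (TD3 target-policy smoothing), then clamp back into bounds.
        with torch.no_grad():
            next_action = target_actor(next_obs)
            noise = (torch.randn_like(next_action) * policy_noise).clamp(
                -noise_clip, noise_clip
            )
            return (next_action + noise).clamp(min_action, max_action)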
@@ -101,6 +103,7 @@ def main(cfg: "DictConfig"): # noqa: F821 tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames @@ -115,17 +118,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() + # compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] + # update critic optimizer_critic.zero_grad() update_actor = i % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) + # update actor if update_actor: optimizer_actor.zero_grad() actor_loss.backward() @@ -133,7 +139,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_losses.append(actor_loss.item()) - # update qnet_target params + # update target params target_net_updater.step() # update priority From 2422ef8d285f32281b46a1ea2fdf5aef10dd9f0e Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 12 Sep 2023 12:18:46 +0200 Subject: [PATCH 06/28] naming fix --- examples/td3/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 50ccd8d3edf..aa589f6dd4a 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -2,7 +2,7 @@ env: name: HalfCheetah-v2 task: "" - exp_name: "HalfCheetah-TD2-ICLR" + exp_name: "HalfCheetah-TD3-ICLR" library: gym frame_skip: 1 seed: 42 From 0e67de241620d12f8ed12408e650cba01e3d30cf Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Sep 2023 10:17:50 +0200 Subject: [PATCH 07/28] single step td3 --- examples/td3/config.yaml | 4 ++-- examples/td3/td3.py | 11 ++++++----- torchrl/objectives/td3.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index aa589f6dd4a..9029c952051 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -5,14 +5,14 @@ env: exp_name: "HalfCheetah-TD3-ICLR" library: gym frame_skip: 1 - seed: 42 + seed: 0 #42 # collector collector: total_frames: 3000000 init_random_frames: 25_000 init_env_steps: 1000 - frames_per_batch: 1000 + frames_per_batch: 1 max_frames_per_traj: 1000 async_collection: 1 collector_device: cpu diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 9a5c803b3d3..d6429dd45e0 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -114,7 +114,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_losses, q_losses, ) = ([], []) - for i in range(num_updates): + for _ in range(num_updates): # sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # update critic optimizer_critic.zero_grad() - update_actor = i % delayed_updates == 0 + update_actor = collected_frames % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) @@ -165,9 +165,10 @@ def main(cfg: "DictConfig"): # noqa: F821 if collected_frames >= init_random_frames: logger.log_scalar("train/q_loss", np.mean(q_losses), step=collected_frames) - logger.log_scalar( - "train/a_loss", np.mean(actor_losses), step=collected_frames - ) + if update_actor: + logger.log_scalar( + "train/a_loss", np.mean(actor_losses), step=collected_frames + ) logger.log_scalar("train/sampling_time", sampling_time, collected_frames) logger.log_scalar("train/training_time", training_time, 
collected_frames) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index c283f161454..b177123e1b6 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -352,7 +352,7 @@ def _cached_stack_actor_params(self): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys - tensordict_save = tensordict + tensordict = tensordict.clone(False).to_tensordict() act = tensordict.get(self.tensor_keys.action) action_shape = act.shape @@ -403,6 +403,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: _next_val_td = ( tensordict_actor[1] .select(*self.qvalue_network.in_keys) + .detach() .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) ) # for next value estimation tensordict_qval = torch.cat( @@ -417,7 +418,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # cat params qvalue_params = torch.cat( [ - self._cached_detach_qvalue_network_params, # self.qvalue_network_params, # + self.qvalue_network_params, self.target_qvalue_network_params, self.qvalue_network_params, ], @@ -440,16 +441,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: dim=0, ) - loss_actor = -(state_action_value_actor[0]).mean() # .min(0) + loss_actor = -(state_action_value_actor[0]).mean() next_state_value = next_state_action_value_qvalue.min(0)[0] tensordict.set( ("next", self.tensor_keys.state_action_value), next_state_value.unsqueeze(-1), ) - target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + target_value = ( + self.value_estimator.value_estimate(tensordict).squeeze(-1).detach() + ) pred_val = state_action_value_qvalue - td_error = (pred_val - target_value).pow(2) loss_qval = ( distance_loss( pred_val, @@ -460,16 +462,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: .sum() ) - tensordict_save.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) - if not loss_qval.shape == loss_actor.shape: raise RuntimeError( f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" ) td_out = TensorDict( source={ - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, "pred_value": pred_val.mean().detach(), "state_action_value_actor": state_action_value_actor.mean().detach(), "next_state_value": next_state_value.mean().detach(), From 1fc08478556a5a92a416e849877497e1c363d2c1 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 14:42:23 +0200 Subject: [PATCH 08/28] small fixes --- examples/td3/config.yaml | 4 ++-- examples/td3/utils.py | 1 + torchrl/objectives/td3.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 9029c952051..bc91e575da4 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,11 +1,11 @@ # task and env env: - name: HalfCheetah-v2 + name: HalfCheetah-v3 task: "" exp_name: "HalfCheetah-TD3-ICLR" library: gym frame_skip: 1 - seed: 0 #42 + seed: 42 # collector collector: diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 4419900cc8e..a586c25d37d 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -228,6 +228,7 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, + gamma=cfg.optim.gamma, action_spec=model[0][1].spec, policy_noise=cfg.optim.policy_noise, noise_clip=cfg.optim.noise_clip, diff --git a/torchrl/objectives/td3.py 
b/torchrl/objectives/td3.py index b177123e1b6..9b46175b001 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -418,7 +418,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # cat params qvalue_params = torch.cat( [ - self.qvalue_network_params, + self._cached_detach_qvalue_network_params, self.target_qvalue_network_params, self.qvalue_network_params, ], From 7a02b830fc0221078c7103253d0866841ecbd68f Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 14:43:18 +0200 Subject: [PATCH 09/28] fix --- examples/td3/td3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index d6429dd45e0..bcb98b9117a 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -150,6 +150,8 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] + + # logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][ tensordict["next", "done"] From 243d7126f92a7bafb8dda2369f6fba6b9b5e5dd3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 15:00:09 +0200 Subject: [PATCH 10/28] add update counter --- examples/td3/td3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index bcb98b9117a..cb755b85536 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -90,6 +90,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip eval_iter = cfg.logger.eval_iter frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip + update_counter = 0 sampling_start = time.time() for tensordict in collector: @@ -115,6 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 q_losses, ) = ([], []) for _ in range(num_updates): + update_counter += 1 # sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() @@ -126,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # update critic optimizer_critic.zero_grad() - update_actor = collected_frames % delayed_updates == 0 + update_actor = update_counter % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) From af31bd9ba1cce2e6b1a55820cbe9e49377a4e64b Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 15:10:01 +0200 Subject: [PATCH 11/28] naming fixes --- examples/td3/td3.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index cb755b85536..5e77273bd28 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -36,7 +36,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) - # Create Logger + # Create logger exp_name = generate_exp_name("TD3", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -51,19 +51,19 @@ def main(cfg: "DictConfig"): # noqa: F821 torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + # Create agent model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device) # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Create Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Create Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( 
batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, @@ -71,7 +71,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device=device, ) - # Create Optimizers + # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) # Main loop @@ -97,18 +97,18 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_time = time.time() - sampling_start exploration_policy.step(tensordict.numel()) - # update weights of the inference policy + # Update weights of the inference policy collector.update_policy_weights_() pbar.update(tensordict.numel()) tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # add to replay buffer + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps training_start = time.time() if collected_frames >= init_random_frames: ( @@ -117,23 +117,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ) = ([], []) for _ in range(num_updates): update_counter += 1 - # sample from replay buffer + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() - # compute loss + # Compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] - # update critic + # Update critic optimizer_critic.zero_grad() update_actor = update_counter % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) - # update actor + # Update actor if update_actor: optimizer_actor.zero_grad() actor_loss.backward() @@ -141,10 +141,10 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_losses.append(actor_loss.item()) - # update target params + # Update target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) @@ -153,7 +153,7 @@ def main(cfg: "DictConfig"): # noqa: F821 tensordict["next", "done"] ] - # logging + # Logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][ tensordict["next", "done"] @@ -176,7 +176,7 @@ def main(cfg: "DictConfig"): # noqa: F821 logger.log_scalar("train/sampling_time", sampling_time, collected_frames) logger.log_scalar("train/training_time", training_time, collected_frames) - # evaluation + # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() From 112280874d747102b8bdf524d7e64766b85261e9 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 15 Sep 2023 09:10:54 +0200 Subject: [PATCH 12/28] update logging and small fixes --- examples/td3/config.yaml | 4 ++-- examples/td3/td3.py | 32 ++++++++++++++++---------------- examples/td3/utils.py | 32 +++++++++++++++++++++----------- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index bc91e575da4..433b395b1df 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -2,7 +2,7 @@ env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-TD3-ICLR" + exp_name: "HalfCheetah-TD3" library: gym frame_skip: 1 seed: 42 @@ -12,7 +12,7 @@ collector: total_frames: 3000000 init_random_frames: 25_000 init_env_steps: 1000 - frames_per_batch: 1 + frames_per_batch: 1000 max_frames_per_traj: 1000 async_collection: 1 collector_device: cpu diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 5e77273bd28..96d6c170068 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -23,6 +23,7 
@@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_environment, make_loss_module, @@ -116,7 +117,11 @@ def main(cfg: "DictConfig"): # noqa: F821 q_losses, ) = ([], []) for _ in range(num_updates): + + # Update actor every delayed_updates update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() @@ -128,7 +133,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update critic optimizer_critic.zero_grad() - update_actor = update_counter % delayed_updates == 0 q_loss.backward(retain_graph=update_actor) optimizer_critic.step() q_losses.append(q_loss.item()) @@ -154,27 +158,22 @@ def main(cfg: "DictConfig"): # noqa: F821 ] # Logging + metrics_to_log = {} if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][ tensordict["next", "done"] ] - logger.log_scalar( - "train/reward", episode_rewards.mean().item(), collected_frames - ) - logger.log_scalar( - "train/episode_length", - episode_length.sum().item() / len(episode_length), - collected_frames, + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) if collected_frames >= init_random_frames: - logger.log_scalar("train/q_loss", np.mean(q_losses), step=collected_frames) + metrics_to_log["train/q_loss"] = np.mean(q_losses) if update_actor: - logger.log_scalar( - "train/a_loss", np.mean(actor_losses), step=collected_frames - ) - logger.log_scalar("train/sampling_time", sampling_time, collected_frames) - logger.log_scalar("train/training_time", training_time, collected_frames) + metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: @@ -188,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821 ) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - logger.log_scalar("eval/reward", eval_reward, step=collected_frames) - logger.log_scalar("eval/time", eval_time, step=collected_frames) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + log_metrics(logger, metrics_to_log, collected_frames) sampling_start = time.time() collector.shutdown() diff --git a/examples/td3/utils.py b/examples/td3/utils.py index a586c25d37d..b4f4e8979a0 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -131,17 +131,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def make_td3_agent(cfg, train_env, eval_env, device): """Make TD3 agent.""" # Define Actor Network @@ -253,3 +242,24 @@ def make_optimizer(cfg, loss_module): weight_decay=cfg.optim.weight_decay, ) return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif 
cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError From b4df32bccca62c03cfc4c6b734bff4540e3fb288 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 18 Sep 2023 10:37:02 +0200 Subject: [PATCH 13/28] no eps --- examples/td3/config.yaml | 1 + examples/td3/utils.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 433b395b1df..1e051cafcea 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -31,6 +31,7 @@ optim: loss_function: l2 lr: 3.0e-4 weight_decay: 0.0 + adam_eps: 1e-8 batch_size: 256 target_update_polyak: 0.995 policy_update_delay: 2 diff --git a/examples/td3/utils.py b/examples/td3/utils.py index b4f4e8979a0..27fc420de5b 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -234,12 +234,16 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) optimizer_critic = optim.Adam( critic_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) return optimizer_actor, optimizer_critic From 13f367a1250bd40e8326ec84e29c141da042b468 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 19 Sep 2023 10:53:23 +0200 Subject: [PATCH 14/28] update tests --- .github/unittest/linux_examples/scripts/run_test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index d81e90fdd42..cc28e8db80d 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -138,7 +138,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ @@ -259,7 +259,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ From 72ddf7e2492fffd5c22c1679a9cca8068db3ac81 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 20 Sep 2023 18:42:46 +0200 Subject: [PATCH 15/28] update objective --- torchrl/objectives/td3.py | 183 +++++++++++++++++++------------------- 1 file changed, 90 insertions(+), 93 deletions(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 9b46175b001..ec0fa914987 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -349,119 +349,118 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - obs_keys = self.actor_network.in_keys + def actor_loss(self, tensordict): + tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys) + 
tensordict_actor_grad = self.actor_network( + tensordict_actor_grad, self.actor_network_params + ) + actor_loss_td = tensordict_actor_grad.select( + *self.qvalue_network.in_keys + ).expand( + self.num_qvalue_nets, *tensordict_actor_grad.batch_size + ) # for actor loss + state_action_value_actor = ( + self._vmap_qvalue_network00( + actor_loss_td, + self.qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + loss_actor = -(state_action_value_actor[0]).mean() + metadata = { + "state_action_value_actor": state_action_value_actor.mean().detach(), + } + return loss_actor, metadata + + def value_loss(self, tensordict): + tensordict = tensordict.clone(False) - tensordict = tensordict.clone(False).to_tensordict() act = tensordict.get(self.tensor_keys.action) - action_shape = act.shape - action_device = act.device - # computing early for reprod - noise = ( - (torch.randn(action_shape) * self.policy_noise) - .clamp(-self.noise_clip, self.noise_clip) - .to(action_device) - ) - tensordict_actor_grad = tensordict.select( - *obs_keys - ) # to avoid overwriting keys - next_td_actor = step_mdp(tensordict).select( - *self.actor_network.in_keys - ) # next_observation -> - tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - # DO NOT call contiguous bc we'll update the tds later - actor_output_td = self._vmap_actor_network00( - tensordict_actor, - self._cached_stack_actor_params.to_tensordict(), + # computing early for reprod + noise = (torch.randn_like(act) * self.policy_noise).clamp( + -self.noise_clip, self.noise_clip ) - # add noise to target policy with torch.no_grad(): - actor_output_td1 = actor_output_td[1] - next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp( + next_td_actor = step_mdp(tensordict).select( + *self.actor_network.in_keys + ) # next_observation -> + next_td_actor = self.actor_network( + next_td_actor, self.target_actor_network_params + ) + next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( self.min_action, self.max_action ) - actor_output_td1.set(self.tensor_keys.action, next_action) - actor_output_td = torch.stack([actor_output_td[0], actor_output_td1], 0) - tensordict_actor.set( - self.tensor_keys.action, - actor_output_td.get(self.tensor_keys.action), + next_td_actor.set( + self.tensor_keys.action, + next_action, + ) + next_val_td = next_td_actor.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, *next_td_actor.batch_size + ) # for next value estimation + next_target_q1q2 = ( + self._vmap_qvalue_network00( + next_val_td, + self.target_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + # min over the next target qvalues + next_target_qvalue = next_target_q1q2.min(0)[0] + + # set next target qvalues + tensordict.set( + ("next", self.tensor_keys.state_action_value), + next_target_qvalue.unsqueeze(-1), ) - # repeat tensordict_actor to match the qvalue size - _actor_loss_td = ( - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) - ) # for actor loss - _qval_td = tensordict.select(*self.qvalue_network.in_keys).expand( + qval_td = tensordict.select(*self.qvalue_network.in_keys).expand( self.num_qvalue_nets, - *tensordict.select(*self.qvalue_network.in_keys).batch_size, - ) # for qvalue loss - _next_val_td = ( - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .detach() - .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) - ) # for 
next value estimation - tensordict_qval = torch.cat( - [ - _actor_loss_td, - _next_val_td, - _qval_td, - ], - 0, + *tensordict.batch_size, ) - - # cat params - qvalue_params = torch.cat( - [ - self._cached_detach_qvalue_network_params, - self.target_qvalue_network_params, + # preditcted current qvalues + current_qvalue = ( + self._vmap_qvalue_network00( + qval_td, self.qvalue_network_params, - ], - 0, - ) - tensordict_qval = self._vmap_qvalue_network00( - tensordict_qval, - qvalue_params, - ) - - state_action_value = tensordict_qval.get( - self.tensor_keys.state_action_value - ).squeeze(-1) - ( - state_action_value_actor, - next_state_action_value_qvalue, - state_action_value_qvalue, - ) = state_action_value.split( - [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], - dim=0, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) ) - loss_actor = -(state_action_value_actor[0]).mean() + # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done)) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - next_state_value = next_state_action_value_qvalue.min(0)[0] - tensordict.set( - ("next", self.tensor_keys.state_action_value), - next_state_value.unsqueeze(-1), - ) - target_value = ( - self.value_estimator.value_estimate(tensordict).squeeze(-1).detach() - ) - pred_val = state_action_value_qvalue + td_error = (current_qvalue - target_value).pow(2) loss_qval = ( distance_loss( - pred_val, - target_value.expand_as(pred_val), + current_qvalue, + target_value.expand_as(current_qvalue), loss_function=self.loss_function, ) .mean(-1) .sum() ) + metadata = { + "td_error": td_error, + "pred_value": current_qvalue.mean().detach(), + "next_state_value": next_target_qvalue.mean().detach(), + "target_value": target_value.mean().detach(), + } + return loss_qval, metadata + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict_save = tensordict + loss_qval, metadata_value = self.value_loss(tensordict) + loss_actor, metadata_actor = self.actor_loss(tensordict) + tensordict_save.set( + self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] + ) if not loss_qval.shape == loss_actor.shape: raise RuntimeError( f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" @@ -470,10 +469,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: source={ "loss_actor": loss_actor, "loss_qvalue": loss_qval, - "pred_value": pred_val.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "next_state_value": next_state_value.mean().detach(), - "target_value": target_value.mean().detach(), + **metadata_actor, + **metadata_value, }, batch_size=[], ) From c830891389fca773aec4317faced41564de384c0 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 20 Sep 2023 18:55:59 +0200 Subject: [PATCH 16/28] set gym backend --- examples/td3/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 27fc420de5b..64bbe349b0f 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
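Note on the PATCH 15 refactor above: TD3Loss.forward is split into actor_loss and value_loss, and the critic target takes the minimum over the two target Q-networks before bootstrapping. In the actual loss the bootstrapping is delegated to self.value_estimator.value_estimate; the sketch below just spells out the standard one-step TD(0) form it reduces to, with q1_target, q2_target, reward, done and gamma as assumed inputs:

    import torch

    def clipped_double_q_target(q1_target, q2_target, reward, done, gamma=0.99):
        # Pessimistic target: elementwise minimum of the two target critics,
        # then a one-step bootstrapped return with (1 - done) masking.
        next_q = torch.min(q1_target, q2_target)
        return reward + gamma * (1.0 - done.float()) * next_q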
+ import torch from torch import nn, optim @@ -13,7 +18,7 @@ RewardSum, TransformedEnv, ) -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.transforms import RewardScaling from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( @@ -35,7 +40,10 @@ def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) + with set_gym_backend("gym"): + return GymEnv( + task, device=device, frame_skip=frame_skip, from_pixels=from_pixels + ) def apply_env_transforms(env, reward_scaling=1.0): From 4cdbb3bf11fba7f1f924941b5ead4611f45ae187 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Sep 2023 18:21:49 +0200 Subject: [PATCH 17/28] update tests --- test/test_cost.py | 6 ++++-- torchrl/objectives/td3.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6b940331f7e..ef1b655f31d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -2357,8 +2357,10 @@ def test_td3_notensordict(self, observation_key, reward_key, done_key): loss_val_td = loss(td) torch.manual_seed(0) loss_val = loss(**kwargs) - for i, key in enumerate(loss_val_td.keys()): - torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + for i in loss_val: + assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" + # for i, key in enumerate(loss_val_td.keys()): + # torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) # test select loss.select_out_keys("loss_actor", "loss_qvalue") torch.manual_seed(0) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index ec0fa914987..b6495605468 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -362,7 +362,7 @@ def actor_loss(self, tensordict): state_action_value_actor = ( self._vmap_qvalue_network00( actor_loss_td, - self.qvalue_network_params, + self._cached_detach_qvalue_network_params, ) .get(self.tensor_keys.state_action_value) .squeeze(-1) @@ -446,8 +446,8 @@ def value_loss(self, tensordict): ) metadata = { "td_error": td_error, - "pred_value": current_qvalue.mean().detach(), "next_state_value": next_target_qvalue.mean().detach(), + "pred_value": current_qvalue.mean().detach(), "target_value": target_value.mean().detach(), } @@ -456,8 +456,8 @@ def value_loss(self, tensordict): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_save = tensordict - loss_qval, metadata_value = self.value_loss(tensordict) loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.value_loss(tensordict_save) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] ) From 76dcdebb05e3172121db192a6ea5f8e9bc6f77c3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 22 Sep 2023 13:58:43 +0200 Subject: [PATCH 18/28] update fix max episode steps --- examples/td3/config.yaml | 2 +- examples/td3/td3.py | 11 ++++++----- examples/td3/utils.py | 22 ++++++++++++++++++---- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 1e051cafcea..2ae2f84affa 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -6,6 +6,7 @@ env: library: gym frame_skip: 1 seed: 42 + max_episode_steps: 1_000_000 # collector collector: @@ -14,7 +15,6 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 - 
async_collection: 1 collector_device: cpu env_per_collector: 1 num_workers: 1 diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 96d6c170068..62469f193f9 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -153,16 +153,17 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer.update_priority(sampled_tensordict) training_time = time.time() - training_start - episode_rewards = tensordict["next", "episode_reward"][ + episode_end = ( tensordict["next", "done"] - ] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] # Logging metrics_to_log = {} if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] + episode_length = tensordict["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( episode_length diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 64bbe349b0f..34f04901e94 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -39,10 +39,16 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): +def env_maker( + task, frame_skip=1, device="cpu", from_pixels=False, max_episode_steps=1000 +): with set_gym_backend("gym"): return GymEnv( - task, device=device, frame_skip=frame_skip, from_pixels=from_pixels + task, + device=device, + frame_skip=frame_skip, + from_pixels=from_pixels, + max_episode_steps=max_episode_steps, ) @@ -63,7 +69,11 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator( + lambda: env_maker( + task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps + ) + ), ) parallel_env.set_seed(cfg.env.seed) @@ -72,7 +82,11 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator( + lambda: env_maker( + task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps + ) + ), ), train_env.transform.clone(), ) From ec8b089f620422cfcc4ad107c56a95c4753f770b Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Sep 2023 08:39:07 +0200 Subject: [PATCH 19/28] fix --- examples/td3/config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 2ae2f84affa..1c17b49c167 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,16 +1,16 @@ # task and env env: - name: HalfCheetah-v3 + name: Hopper-v3 task: "" - exp_name: "HalfCheetah-TD3" + exp_name: "Hopper-TD3" library: gym frame_skip: 1 seed: 42 - max_episode_steps: 1_000_000 + max_episode_steps: 1000 # collector collector: - total_frames: 3000000 + total_frames: 1000000 init_random_frames: 25_000 init_env_steps: 1000 frames_per_batch: 1000 From bcc3bc62a01fbd3600da8386e9b7bbf3e98413d3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Sep 2023 10:50:46 +0200 Subject: [PATCH 20/28] fix --- examples/td3/config.yaml | 7 ++++--- examples/td3/utils.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 1c17b49c167..d489db506b9 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,12 +1,12 @@ # task and env env: - name: Hopper-v3 + name: 
Walker2d-v3 task: "" - exp_name: "Hopper-TD3" + exp_name: "Walker2d-TD3" library: gym frame_skip: 1 seed: 42 - max_episode_steps: 1000 + max_episode_steps: 5000 # collector collector: @@ -15,6 +15,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 + reset_at_each_iter: False collector_device: cpu env_per_collector: 1 num_workers: 1 diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 34f04901e94..8fb8b2d55e1 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -107,6 +107,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, + reset_at_each_iter=cfg.collector.reset_at_each_iter, device=cfg.collector.collector_device, ) collector.set_seed(cfg.env.seed) From 42748e0581b0682d742db27fb972fa1802f5d16a Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 28 Sep 2023 04:54:44 -0400 Subject: [PATCH 21/28] amend --- examples/td3/utils.py | 69 +++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 64bbe349b0f..2fc60a7931e 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import tempfile +from contextlib import nullcontext import torch @@ -39,11 +41,9 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): +def env_maker(task, device="cpu"): with set_gym_backend("gym"): - return GymEnv( - task, device=device, frame_skip=frame_skip, from_pixels=from_pixels - ) + return GymEnv(task, device=device) def apply_env_transforms(env, reward_scaling=1.0): @@ -63,7 +63,7 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ) parallel_env.set_seed(cfg.env.seed) @@ -72,7 +72,7 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ), train_env.transform.clone(), ) @@ -103,35 +103,40 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir="/tmp/", + buffer_scratch_dir=None, device="cpu", prefetch=3, ): - if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - else: - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - return replay_buffer + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + 
pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return replay_buffer # ==================================================================== From e2c28c8bd15e740963641d68f5c15cbeff02fb2b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 28 Sep 2023 05:40:08 -0400 Subject: [PATCH 22/28] amend --- examples/td3/config.yaml | 6 +-- examples/td3/td3.py | 10 ++--- examples/td3/utils.py | 34 ++++++---------- torchrl/collectors/collectors.py | 58 +++++++++++++++++---------- torchrl/envs/transforms/transforms.py | 7 ++++ 5 files changed, 63 insertions(+), 52 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 2ae2f84affa..b6a3c567cea 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -6,7 +6,7 @@ env: library: gym frame_skip: 1 seed: 42 - max_episode_steps: 1_000_000 + max_episode_steps: 1000 # collector collector: @@ -14,7 +14,7 @@ collector: init_random_frames: 25_000 init_env_steps: 1000 frames_per_batch: 1000 - max_frames_per_traj: 1000 + max_frames_per_traj: collector_device: cpu env_per_collector: 1 num_workers: 1 @@ -31,7 +31,7 @@ optim: loss_function: l2 lr: 3.0e-4 weight_decay: 0.0 - adam_eps: 1e-8 + adam_eps: 1e-4 batch_size: 256 target_update_polyak: 0.995 policy_update_delay: 2 diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 62469f193f9..f937b818413 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -88,7 +88,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb - eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip + eval_rollout_steps = cfg.env.max_episode_steps // cfg.env.frame_skip eval_iter = cfg.logger.eval_iter frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip update_counter = 0 @@ -126,19 +126,17 @@ def main(cfg: "DictConfig"): # noqa: F821 sampled_tensordict = replay_buffer.sample().clone() # Compute loss - loss_td = loss_module(sampled_tensordict) - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] + q_loss, *_ = loss_module.value_loss(sampled_tensordict) # Update critic optimizer_critic.zero_grad() - q_loss.backward(retain_graph=update_actor) + q_loss.backward() optimizer_critic.step() q_losses.append(q_loss.item()) # Update actor if update_actor: + actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) optimizer_actor.zero_grad() actor_loss.backward() optimizer_actor.step() diff --git a/examples/td3/utils.py b/examples/td3/utils.py index fff1bba1de6..414321c2c2f 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -18,6 +18,7 @@ InitTracker, ParallelEnv, RewardSum, + StepCounter, TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend @@ -41,27 +42,26 @@ # ----------------- -def env_maker( - task, device="cpu", max_episode_steps=1000 -): +def env_maker(task, device="cpu"): with set_gym_backend("gym"): return GymEnv( task, device=device, - max_episode_steps=max_episode_steps, ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps, reward_scaling=1.0): transformed_env = TransformedEnv( env, Compose( + StepCounter(max_steps=max_episode_steps), InitTracker(), - RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat("observation"), - RewardSum(), + DoubleToFloat(), ), ) + if reward_scaling != 1.0: + transformed_env.append_transform(RewardScaling(loc=0.0, 
scale=reward_scaling)) + transformed_env.append_transform(RewardSum()) return transformed_env @@ -69,26 +69,18 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator( - lambda task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps: env_maker( - task=task, max_episode_steps=max_episode_steps - ) - ), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms(parallel_env) + train_env = apply_env_transforms( + parallel_env, max_episode_steps=cfg.env.max_episode_steps + ) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator( - lambda - task=cfg.env.name, - max_episode_steps=cfg.env.max_episode_steps: env_maker( - task=task, max_episode_steps=max_episode_steps - ) - ), + EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), ), train_env.transform.clone(), ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 5057d42119b..2841a346c75 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import _pickle import abc import inspect @@ -390,12 +392,12 @@ class SyncDataCollector(DataCollectorBase): If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e. no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. 
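For readers following along outside the patch: a minimal sketch of how a SyncDataCollector might be driven with the optional defaults described in the docstring above (max_frames_per_traj=None, init_random_frames=None). The Pendulum-v1 task and the throwaway linear policy are illustrative assumptions, not part of this series.

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    # Tiny deterministic policy mapping observations to actions, just to feed the collector.
    policy = TensorDictModule(
        torch.nn.LazyLinear(env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=1000,
        total_frames=10_000,
        init_random_frames=None,   # no random warm-up phase
        max_frames_per_traj=None,  # no per-trajectory cap handled by the collector
        reset_at_each_iter=False,
        device="cpu",
    )
    for batch in collector:
        break  # each `batch` is a TensorDict holding 1000 transitions
    collector.shutdown()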
@@ -497,12 +499,12 @@ def __init__( total_frames: int, device: DEVICE_TYPING = None, storing_device: DEVICE_TYPING = None, - create_env_kwargs: Optional[dict] = None, - max_frames_per_traj: int = -1, - init_random_frames: int = -1, + create_env_kwargs: dict | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, - postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, - split_trajs: Optional[bool] = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, exploration_mode=None, return_same_td: bool = False, @@ -566,7 +568,7 @@ def __init__( self.env: EnvBase = self.env.to(self.device) self.max_frames_per_traj = max_frames_per_traj - if self.max_frames_per_traj > 0: + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: # let's check that there is no StepCounter yet for key in self.env.output_spec.keys(True, True): if isinstance(key, str): @@ -869,7 +871,10 @@ def rollout(self) -> TensorDictBase: tensordicts = [] with set_exploration_type(self.exploration_type): for t in range(self.frames_per_batch): - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): self.env.rand_step(self._tensordict) else: self.policy(self._tensordict) @@ -1062,12 +1067,12 @@ class _MultiDataCollector(DataCollectorBase): If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e. no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. 
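Dropping the collector-side cap goes hand in hand with the StepCounter check added to SyncDataCollector.__init__ above: episode length is now expected to be bounded inside the environment. A small sketch of that pattern, assuming a lightweight Pendulum-v1 stand-in with a 100-step cap rather than the example's MuJoCo tasks:

    from torchrl.envs import (
        Compose,
        DoubleToFloat,
        InitTracker,
        RewardSum,
        StepCounter,
        TransformedEnv,
    )
    from torchrl.envs.libs.gym import GymEnv

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            StepCounter(max_steps=100),  # sets "truncated" after 100 steps
            InitTracker(),
            DoubleToFloat(),
            RewardSum(),
        ),
    )
    rollout = env.rollout(1000)  # ends early once the StepCounter truncates
    print(rollout["next", "step_count"].max())  # at most 100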
@@ -1123,8 +1128,8 @@ def __init__( device: DEVICE_TYPING = None, storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, create_env_kwargs: Optional[Sequence[dict]] = None, - max_frames_per_traj: int = -1, - init_random_frames: int = -1, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, @@ -1677,7 +1682,10 @@ def iterator(self) -> Iterator[TensorDictBase]: self.update_policy_weights_() for idx in range(self.num_workers): - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): msg = "continue_random" else: msg = "continue" @@ -1913,7 +1921,7 @@ def iterator(self) -> Iterator[TensorDictBase]: self.update_policy_weights_() for i in range(self.num_workers): - if self.init_random_frames > 0: + if self.init_random_frames is not None and self.init_random_frames > 0: self.pipes[i].send((None, "continue_random")) else: self.pipes[i].send((None, "continue")) @@ -1935,7 +1943,10 @@ def iterator(self) -> Iterator[TensorDictBase]: # the function blocks here until the next item is asked, hence we send the message to the # worker to keep on working in the meantime before the yield statement - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): msg = "continue_random" else: msg = "continue" @@ -1962,7 +1973,10 @@ def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None: raise Exception("self.queue_out is full") if self.running: for idx in range(self.num_workers): - if self._frames < self.init_random_frames: + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): self.pipes[idx].send((idx, "continue_random")) else: self.pipes[idx].send((idx, "continue")) @@ -1996,14 +2010,14 @@ class aSyncDataCollector(MultiaSyncDataCollector): environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Default is -1 (i.e. no maximum number of steps) + Defaults to ``None`` (i.e. no maximum number of steps) frames_per_batch (int): Time-length of a batch. reset_at_each_iter and frames_per_batch == n_steps are equivalent configurations. - default: 200 + Defaults to ``200`` init_random_frames (int): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - default=-1 (i.e. no random frames) + Defaults to ``None`` (i.e. no random frames) reset_at_each_iter (bool): Whether or not environments should be reset for each batch. default=False. 
postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a @@ -2038,9 +2052,9 @@ def __init__( ] = None, total_frames: Optional[int] = -1, create_env_kwargs: Optional[dict] = None, - max_frames_per_traj: int = -1, + max_frames_per_traj: int | None = None, frames_per_batch: int = 200, - init_random_frames: int = -1, + init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d491d3bfa03..e0370d2be50 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2830,11 +2830,18 @@ def _set_in_keys(self): self._keys_inv_unset = False self._container.empty_cache() + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._keys_unset: + self._set_in_keys() + return super().reset(tensordict) + @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" if self._keys_unset: self._set_in_keys() + # if still no update + if self._keys_unset: for in_key, data in tensordict.items(True, True): if data.dtype == self.dtype_in: out_key = in_key From bb496ef9ff5518414d497d1c438873014bef6e37 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 28 Sep 2023 16:16:59 +0200 Subject: [PATCH 23/28] update scratch_dir, frame skip, config --- examples/td3/config.yaml | 4 ++-- examples/td3/td3.py | 7 ++++--- examples/td3/utils.py | 5 +---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index d489db506b9..6d5b6a299bd 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -2,9 +2,8 @@ env: name: Walker2d-v3 task: "" - exp_name: "Walker2d-TD3" + exp_name: ${env.name}_TD3 library: gym - frame_skip: 1 seed: 42 max_episode_steps: 5000 @@ -24,6 +23,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 + scratch_dir: ${env.exp_name}_${env.seed} # optim optim: diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 62469f193f9..e6744815bdd 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -69,6 +69,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) @@ -88,9 +89,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb - eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip + eval_rollout_steps = cfg.collector.max_frames_per_traj eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip + frames_per_batch = cfg.collector.frames_per_batch update_counter = 0 sampling_start = time.time() @@ -177,7 +178,7 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 
8fb8b2d55e1..2b41700f004 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -39,14 +39,11 @@ # ----------------- -def env_maker( - task, frame_skip=1, device="cpu", from_pixels=False, max_episode_steps=1000 -): +def env_maker(task, device="cpu", from_pixels=False, max_episode_steps=1000): with set_gym_backend("gym"): return GymEnv( task, device=device, - frame_skip=frame_skip, from_pixels=from_pixels, max_episode_steps=max_episode_steps, ) From e622bf77695a4c54aa4d3f6b460ebfa691e5bca8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 2 Oct 2023 10:38:00 +0200 Subject: [PATCH 24/28] merge main --- examples/td3/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 6d5b6a299bd..caa2d139123 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,6 +1,6 @@ # task and env env: - name: Walker2d-v3 + name: HalfCheetah-v3 task: "" exp_name: ${env.name}_TD3 library: gym From 29977df992ec2e0eba12c9e59e85fec6a1c963a4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 2 Oct 2023 11:01:26 +0200 Subject: [PATCH 25/28] step counter --- examples/td3/config.yaml | 1 - examples/td3/td3.py | 2 +- examples/td3/utils.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 43d2be9a711..3a0e01dfdf5 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -13,7 +13,6 @@ collector: init_random_frames: 25_000 init_env_steps: 1000 frames_per_batch: 1000 - max_frames_per_traj: 1000 reset_at_each_iter: False collector_device: cpu env_per_collector: 1 diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 1bda16fd09b..a6d1ee93617 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -89,7 +89,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb - eval_rollout_steps = cfg.collector.max_frames_per_traj + eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch update_counter = 0 diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 54457724a94..4fcb3ae78c1 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -104,7 +104,6 @@ def make_collector(cfg, train_env, actor_model_explore): actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, device=cfg.collector.collector_device, From 619f2ea6c009c7fac6d9be21eb6e4367edc5e8e9 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 3 Oct 2023 09:24:31 +0200 Subject: [PATCH 26/28] small fixes --- examples/td3/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 4fcb3ae78c1..090529782fd 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -22,7 +22,6 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.transforms import RewardScaling from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( AdditiveGaussianWrapper, @@ -55,18 +54,16 @@ def env_maker( ) -def apply_env_transforms(env, max_episode_steps, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps): transformed_env = TransformedEnv( env, Compose( 
StepCounter(max_steps=max_episode_steps), InitTracker(), DoubleToFloat(), + RewardSum(), ), ) - if reward_scaling != 1.0: - transformed_env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) - transformed_env.append_transform(RewardSum()) return transformed_env From 8d3678740a2ad92b029a6492b696af0dfe405fe9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 07:13:18 -0400 Subject: [PATCH 27/28] solve logger issue --- examples/td3/td3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index a6d1ee93617..7c9904f5300 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -189,8 +189,8 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward metrics_to_log["eval/time"] = eval_time - - log_metrics(logger, metrics_to_log, collected_frames) + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) sampling_start = time.time() collector.shutdown() From a24ab8d57038549365083bd168e72cd64c9573cc Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 07:19:51 -0400 Subject: [PATCH 28/28] reset notensordict test --- test/test_cost.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index b8956645028..6c38e6a8b65 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -2419,8 +2419,10 @@ def test_td3_notensordict( loss_val = loss(**kwargs) for i in loss_val: assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" - # for i, key in enumerate(loss_val_td.keys()): - # torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + + for i, key in enumerate(loss.out_keys): + torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + # test select loss.select_out_keys("loss_actor", "loss_qvalue") torch.manual_seed(0)
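Taken together, the td3.py changes in this series split each TD3 update into a critic step on every iteration and a delayed actor step. A condensed sketch of that pattern, wrapped in a function for readability; the arguments mirror the objects built in examples/td3/utils.py, and target_net_updater is assumed to be a SoftUpdate instance (neither is defined here).

    def td3_update_step(
        loss_module,
        replay_buffer,
        optimizer_actor,
        optimizer_critic,
        target_net_updater,
        update_counter,
        policy_update_delay=2,
    ):
        """One TD3 gradient step: critic every call, actor every `policy_update_delay` calls."""
        update_actor = update_counter % policy_update_delay == 0
        sampled_tensordict = replay_buffer.sample().clone()

        # Critic update
        q_loss, *_ = loss_module.value_loss(sampled_tensordict)
        optimizer_critic.zero_grad()
        q_loss.backward()
        optimizer_critic.step()

        actor_loss = None
        if update_actor:
            # Delayed actor update, then a Polyak update of the target networks
            actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()
            target_net_updater.step()

        return q_loss.detach(), actor_loss if actor_loss is None else actor_loss.detach()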