From 03572fc286f1199e97d0f23a09ce0c237fda51ba Mon Sep 17 00:00:00 2001
From: BY571
Date: Thu, 21 Sep 2023 12:03:20 +0200
Subject: [PATCH 1/5] update

---
 examples/decision_transformer/odt_config.yaml |  8 +--
 examples/decision_transformer/online_dt.py    | 55 +++++++++++++------
 examples/decision_transformer/utils.py        | 13 ++++-
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml
index de8d5ffb6af..c332e024ccb 100644
--- a/examples/decision_transformer/odt_config.yaml
+++ b/examples/decision_transformer/odt_config.yaml
@@ -1,4 +1,4 @@
-# Task and env
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
@@ -10,7 +10,6 @@ env:
   num_train_envs: 1
   num_eval_envs: 10
   reward_scaling: 0.001 # for r2g
-  noop: 1
   seed: 42
   target_return_mode: reduce
   eval_target_return: 6000
@@ -26,7 +25,7 @@ logger:
   fintune_log_interval: 1
   eval_steps: 1000
 
-# Buffer
+# replay buffer
 replay_buffer:
   dataset: halfcheetah-medium-v2
   batch_size: 256
@@ -38,7 +37,7 @@ replay_buffer:
   device: cuda:0
   prefetch: 3
 
-# Optimization
+# optimizer
 optim:
   device: cuda:0
   lr: 1.0e-4
@@ -55,6 +54,7 @@ loss:
   alpha_init: 0.1
   target_entropy: auto
 
+# transformer model
 transformer:
   n_embd: 512
   n_layer: 4
diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py
index 01ab12dfabd..f4b8955865b 100644
--- a/examples/decision_transformer/online_dt.py
+++ b/examples/decision_transformer/online_dt.py
@@ -7,16 +7,19 @@
 The helper functions are coded in the utils.py associated with this script.
 """
+import time
+
 import hydra
+import numpy as np
 import torch
 import tqdm
-
 from torchrl.envs.libs.gym import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 
 from utils import (
+    log_metrics,
     make_env,
     make_logger,
     make_odt_loss,
@@ -31,19 +34,34 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     model_device = cfg.optim.device
 
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
     logger = make_logger(cfg)
+
+    # Create offline replay buffer
     offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
         cfg.replay_buffer, cfg.env.reward_scaling
     )
+
+    # Create test environment
     test_env = make_env(cfg.env, obs_loc, obs_std)
 
+    # Create policy model
     actor = make_odt_model(cfg)
     policy = actor.to(model_device)
 
+    # Create loss
     loss_module = make_odt_loss(cfg.loss, policy)
+
+    # Create optimizer
     transformer_optim, temperature_optim, scheduler = make_odt_optimizer(
         cfg.optim, loss_module
     )
+
+    # Create inference policy
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
@@ -51,8 +69,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
-    r0 = None
-    l0 = None
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
     clip_grad = cfg.optim.clip_grad
     eval_steps = cfg.logger.eval_steps
@@ -61,10 +77,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
     print(" ***Pretraining*** ")
 
     # Pretraining
+    start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(i)
+        # sample data
         data = offline_buffer.sample()
-        # loss
+        # compute loss
         loss_vals = loss_module(data.to(model_device))
         transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
         temperature_loss = loss_vals["loss_alpha"]
@@ -80,6 +98,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         scheduler.step()
 
+        # log metrics
+        to_log = {
+            "loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
+            "loss_entropy": loss_vals["loss_entropy"].item(),
+            "loss_alpha": loss_vals["loss_alpha"].item(),
+        }
+
         # evaluation
         with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
             inference_policy.eval()
@@ -91,20 +116,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     break_when_any_done=False,
                 )
             inference_policy.train()
-            if r0 is None:
-                r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if l0 is None:
-                l0 = transformer_loss.item()
-
-            eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if logger is not None:
-                for key, value in loss_vals.items():
-                    logger.log_scalar(key, value.item(), i)
-                logger.log_scalar("evaluation reward", eval_reward, i)
-
-            pbar.set_description(
-                f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-            )
+            to_log["evaluation_reward"] = (
+                eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+            )
+
+        log_metrics(logger, to_log, i)
+
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py
index 768237178c9..7d0600d4fc1 100644
--- a/examples/decision_transformer/utils.py
+++ b/examples/decision_transformer/utils.py
@@ -19,7 +19,6 @@
     DoubleToFloat,
     EnvCreator,
     ExcludeTransform,
-    NoopResetEnv,
     ObservationNorm,
     RandomCropTensorDict,
     Reward2GoTransform,
@@ -65,8 +64,6 @@ def make_base_env(env_cfg):
         env_task = env_cfg.task
         env_kwargs.update({"task_name": env_task})
     env = env_library(**env_kwargs)
-    if env_cfg.noop > 1:
-        env = TransformedEnv(env, NoopResetEnv(env_cfg.noop))
     return env
 
 
@@ -472,3 +469,13 @@ def make_logger(cfg):
         wandb_kwargs={"config": cfg},
     )
     return logger
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    for metric_name, metric_value in metrics.items():
+        logger.log_scalar(metric_name, metric_value, step)

From e6d974cb7249faf9fa67f09cd9af1fc616418f19 Mon Sep 17 00:00:00 2001
From: BY571
Date: Thu, 21 Sep 2023 12:16:28 +0200
Subject: [PATCH 2/5] fixes

---
 examples/decision_transformer/dt.py        | 60 ++++++++++++++--------
 examples/decision_transformer/online_dt.py |  8 +--
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py
index 30e19608cf7..2766bbdf4eb 100644
--- a/examples/decision_transformer/dt.py
+++ b/examples/decision_transformer/dt.py
@@ -6,15 +6,19 @@
 This is a self-contained example of an offline Decision Transformer training script.
 The helper functions are coded in the utils.py associated with this script.
 """
+import time
 
 import hydra
+import numpy as np
 import torch
 import tqdm
 
+from torchrl.envs.libs.gym import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 
 from utils import (
+    log_metrics,
     make_dt_loss,
     make_dt_model,
     make_dt_optimizer,
@@ -24,19 +28,37 @@
 )
 
 
+@set_gym_backend("gym")  # D4RL uses gym so we make sure gymnasium is hidden
 @hydra.main(config_path=".", config_name="dt_config")
 def main(cfg: "DictConfig"):  # noqa: F821
     model_device = cfg.optim.device
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
     logger = make_logger(cfg)
+
+    # Create offline replay buffer
     offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
         cfg.replay_buffer, cfg.env.reward_scaling
     )
+
+    # Create test environment
     test_env = make_env(cfg.env, obs_loc, obs_std)
+
+    # Create policy model
     actor = make_dt_model(cfg)
     policy = actor.to(model_device)
 
+    # Create loss
     loss_module = make_dt_loss(cfg.loss, actor)
+
+    # Create optimizer
     transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
+
+    # Create inference policy
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
@@ -44,9 +66,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
-    r0 = None
-    l0 = None
-
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
     clip_grad = cfg.optim.clip_grad
     eval_steps = cfg.logger.eval_steps
@@ -55,12 +74,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
     print(" ***Pretraining*** ")
 
     # Pretraining
+    start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(i)
+
+        # Sample data
         data = offline_buffer.sample()
-        # loss
+        # Compute loss
         loss_vals = loss_module(data.to(model_device))
-        # backprop
         transformer_loss = loss_vals["loss"]
 
         transformer_optim.zero_grad()
@@ -70,28 +91,25 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         scheduler.step()
 
-        # evaluation
-        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
+        # Log metrics
+        to_log = {"loss": loss_vals["loss"]}
+
+        # Evaluation
+        with set_exploration_type(ExplorationType.MODE), torch.no_grad():
             if i % pretrain_log_interval == 0:
                 eval_td = test_env.rollout(
                     max_steps=eval_steps,
                     policy=inference_policy,
                     auto_cast_to_device=True,
                 )
-            if r0 is None:
-                r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if l0 is None:
-                l0 = transformer_loss.item()
-
-            eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if logger is not None:
-                for key, value in loss_vals.items():
-                    logger.log_scalar(key, value.item(), i)
-                logger.log_scalar("evaluation reward", eval_reward, i)
-
-            pbar.set_description(
-                f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-            )
+                to_log["evaluation_reward"] = (
+                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+                )
+
+        log_metrics(logger, to_log, i)
+
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py
index f4b8955865b..94bcd553076 100644
--- a/examples/decision_transformer/online_dt.py
+++ b/examples/decision_transformer/online_dt.py
@@ -80,9 +80,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
     start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(i)
-        # sample data
+        # Sample data
         data = offline_buffer.sample()
-        # compute loss
+        # Compute loss
         loss_vals = loss_module(data.to(model_device))
         transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
         temperature_loss = loss_vals["loss_alpha"]
@@ -98,14 +98,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         scheduler.step()
 
-        # log metrics
+        # Log metrics
         to_log = {
             "loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
             "loss_entropy": loss_vals["loss_entropy"].item(),
             "loss_alpha": loss_vals["loss_alpha"].item(),
         }
 
-        # evaluation
+        # Evaluation
         with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
             inference_policy.eval()
             if i % pretrain_log_interval == 0:

From 7063b51609fef5963293e41e5fc32df6c0b221c3 Mon Sep 17 00:00:00 2001
From: BY571
Date: Tue, 26 Sep 2023 08:52:01 +0200
Subject: [PATCH 3/5] update config

---
 examples/decision_transformer/dt_config.yaml  | 10 +++++-----
 examples/decision_transformer/odt_config.yaml |  1 -
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml
index 69ced6be5d8..3514cf2203a 100644
--- a/examples/decision_transformer/dt_config.yaml
+++ b/examples/decision_transformer/dt_config.yaml
@@ -1,4 +1,4 @@
-# Task and env
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
@@ -25,7 +25,7 @@ logger:
   fintune_log_interval: 1
   eval_steps: 1000
 
-# Buffer
+# replay buffer
 replay_buffer:
   dataset: halfcheetah-medium-v2
   batch_size: 64
@@ -37,13 +37,12 @@ replay_buffer:
   device: cpu
   prefetch: 3
 
-# Optimization
+# optimization
 optim:
   device: cuda:0
   lr: 1.0e-4
   weight_decay: 5.0e-4
   batch_size: 64
-  lr_scheduler: ""
   pretrain_gradient_steps: 55000
   updates_per_episode: 300
   warmup_steps: 10000
@@ -52,7 +51,8 @@ optim:
 # loss
 loss:
   loss_function: "l2"
-
+
+# transformer model
 transformer:
   n_embd: 128
   n_layer: 3
diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml
index c332e024ccb..f8aebd30091 100644
--- a/examples/decision_transformer/odt_config.yaml
+++ b/examples/decision_transformer/odt_config.yaml
@@ -43,7 +43,6 @@ optim:
   lr: 1.0e-4
   weight_decay: 5.0e-4
   batch_size: 256
-  lr_scheduler: ""
   pretrain_gradient_steps: 10000
   updates_per_episode: 300
   warmup_steps: 10000

From 911e74e40aff992828993742cba945003cc75677 Mon Sep 17 00:00:00 2001
From: BY571
Date: Wed, 4 Oct 2023 11:53:58 +0200
Subject: [PATCH 4/5] fix train eval logging

---
 examples/decision_transformer/dt.py        |  4 ++--
 examples/decision_transformer/online_dt.py | 10 ++++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py
index 2766bbdf4eb..6b1154d7c7c 100644
--- a/examples/decision_transformer/dt.py
+++ b/examples/decision_transformer/dt.py
@@ -92,7 +92,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         scheduler.step()
 
         # Log metrics
-        to_log = {"loss": loss_vals["loss"]}
+        to_log = {"train/loss": loss_vals["loss"]}
 
         # Evaluation
         with set_exploration_type(ExplorationType.MODE), torch.no_grad():
@@ -102,7 +102,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     policy=inference_policy,
                     auto_cast_to_device=True,
                 )
-                to_log["evaluation_reward"] = (
+                to_log["eval/reward"] = (
                     eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                 )
 
diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py
index 94bcd553076..2f646edc7ef 100644
--- a/examples/decision_transformer/online_dt.py
+++ b/examples/decision_transformer/online_dt.py
@@ -100,9 +100,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         # Log metrics
         to_log = {
-            "loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
-            "loss_entropy": loss_vals["loss_entropy"].item(),
-            "loss_alpha": loss_vals["loss_alpha"].item(),
+            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
+            "train/loss_entropy": loss_vals["loss_entropy"].item(),
+            "train/loss_alpha": loss_vals["loss_alpha"].item(),
+            "train/alpha": loss_vals["alpha"].item(),
+            "train/entropy": loss_vals["entropy"].item(),
         }
 
         # Evaluation
@@ -116,7 +118,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 break_when_any_done=False,
             )
             inference_policy.train()
-            to_log["evaluation_reward"] = (
+            to_log["eval/reward"] = (
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 

From 2578d8fc228b4518fefc85b732a7ae4b563474e9 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 5 Oct 2023 08:56:32 +0100
Subject: [PATCH 5/5] amend

---
 examples/decision_transformer/dt.py        | 4 ++--
 examples/decision_transformer/online_dt.py | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py
index 6b1154d7c7c..f241ce4e975 100644
--- a/examples/decision_transformer/dt.py
+++ b/examples/decision_transformer/dt.py
@@ -105,8 +105,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 to_log["eval/reward"] = (
                     eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                 )
-
-        log_metrics(logger, to_log, i)
+        if logger is not None:
+            log_metrics(logger, to_log, i)
 
     pbar.close()
     print(f"Training time: {time.time() - start_time}")
diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py
index 2f646edc7ef..131320e9e21 100644
--- a/examples/decision_transformer/online_dt.py
+++ b/examples/decision_transformer/online_dt.py
@@ -122,7 +122,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 
-        log_metrics(logger, to_log, i)
+        if logger is not None:
+            log_metrics(logger, to_log, i)
 
     pbar.close()
     print(f"Training time: {time.time() - start_time}")