56 changes: 37 additions & 19 deletions examples/decision_transformer/dt.py
@@ -6,15 +6,19 @@
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper

from utils import (
log_metrics,
make_dt_loss,
make_dt_model,
make_dt_optimizer,
@@ -24,29 +28,44 @@
)


@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden
@hydra.main(config_path=".", config_name="dt_config")
def main(cfg: "DictConfig"): # noqa: F821
model_device = cfg.optim.device

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create logger
logger = make_logger(cfg)

# Create offline replay buffer
offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
cfg.replay_buffer, cfg.env.reward_scaling
)

# Create test environment
test_env = make_env(cfg.env, obs_loc, obs_std)

# Create policy model
actor = make_dt_model(cfg)
policy = actor.to(model_device)

# Create loss
loss_module = make_dt_loss(cfg.loss, actor)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

r0 = None
l0 = None

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.logger.eval_steps
@@ -55,12 +74,14 @@ def main(cfg: "DictConfig"): # noqa: F821

print(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(i)

# Sample data
data = offline_buffer.sample()
# loss
# Compute loss
loss_vals = loss_module(data.to(model_device))
# backprop
transformer_loss = loss_vals["loss"]

transformer_optim.zero_grad()
@@ -70,28 +91,25 @@ def main(cfg: "DictConfig"): # noqa: F821

scheduler.step()

# evaluation
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
# Log metrics
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
policy=inference_policy,
auto_cast_to_device=True,
)
if r0 is None:
r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
if l0 is None:
l0 = transformer_loss.item()

eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
if logger is not None:
for key, value in loss_vals.items():
logger.log_scalar(key, value.item(), i)
logger.log_scalar("evaluation reward", eval_reward, i)
log_metrics(logger, to_log, i)

pbar.set_description(
f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
)
pbar.close()
print(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
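For reference, a minimal sketch (not part of the diff, with placeholder shapes and values) of how the evaluation reward logged above is computed: the rollout's rewards are summed over the time dimension, averaged over the batch of evaluation environments, and divided by the reward scaling so the logged number is the unscaled episodic return.

import torch

reward_scaling = 0.001             # assumed to match cfg.env.reward_scaling in the configs
rewards = torch.rand(10, 1000, 1)  # placeholder [num_eval_envs, eval_steps, 1] reward tensor
# Sum over time, average over envs, then undo the scaling applied by the env transforms.
eval_reward = rewards.sum(1).mean().item() / reward_scaling
print(f"eval/reward: {eval_reward:.4f}")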
10 changes: 5 additions & 5 deletions examples/decision_transformer/dt_config.yaml
@@ -1,4 +1,4 @@
# Task and env
# environment and task
env:
name: HalfCheetah-v3
task: ""
@@ -25,7 +25,7 @@ logger:
fintune_log_interval: 1
eval_steps: 1000

# Buffer
# replay buffer
replay_buffer:
dataset: halfcheetah-medium-v2
batch_size: 64
@@ -37,13 +37,12 @@ replay_buffer:
device: cpu
prefetch: 3

# Optimization
# optimization
optim:
device: cuda:0
lr: 1.0e-4
weight_decay: 5.0e-4
batch_size: 64
lr_scheduler: ""
pretrain_gradient_steps: 55000
updates_per_episode: 300
warmup_steps: 10000
@@ -52,7 +51,8 @@ optim:
# loss
loss:
loss_function: "l2"


# transformer model
transformer:
n_embd: 128
n_layer: 3
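As an aside (not part of the diff), the config above is consumed through Hydra by the dt.py script shown earlier. Below is a minimal sketch of how the resolved DictConfig mirrors the YAML sections; the version_base argument and the printed fields are illustrative assumptions rather than lines from the example.

import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="dt_config", version_base=None)
def show_cfg(cfg: DictConfig) -> None:
    # Sections resolve to attribute access, e.g. the optim and replay_buffer blocks above.
    print(cfg.optim.lr, cfg.replay_buffer.dataset)

if __name__ == "__main__":
    show_cfg()

Individual fields can be overridden on the command line in the usual Hydra way, e.g. optim.lr=3e-4 or replay_buffer.batch_size=128.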
9 changes: 4 additions & 5 deletions examples/decision_transformer/odt_config.yaml
@@ -1,4 +1,4 @@
# Task and env
# environment and task
env:
name: HalfCheetah-v3
task: ""
@@ -10,7 +10,6 @@ env:
num_train_envs: 1
num_eval_envs: 10
reward_scaling: 0.001 # for r2g
noop: 1
seed: 42
target_return_mode: reduce
eval_target_return: 6000
@@ -26,7 +25,7 @@ logger:
fintune_log_interval: 1
eval_steps: 1000

# Buffer
# replay buffer
replay_buffer:
dataset: halfcheetah-medium-v2
batch_size: 256
@@ -38,13 +37,12 @@ replay_buffer:
device: cuda:0
prefetch: 3

# Optimization
# optimizer
optim:
device: cuda:0
lr: 1.0e-4
weight_decay: 5.0e-4
batch_size: 256
lr_scheduler: ""
pretrain_gradient_steps: 10000
updates_per_episode: 300
warmup_steps: 10000
@@ -55,6 +53,7 @@ loss:
alpha_init: 0.1
target_entropy: auto

# transformer model
transformer:
n_embd: 512
n_layer: 4
54 changes: 38 additions & 16 deletions examples/decision_transformer/online_dt.py
@@ -7,16 +7,19 @@
The helper functions are coded in the utils.py associated with this script.
"""

import time

import hydra
import numpy as np
import torch
import tqdm

from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper

from utils import (
log_metrics,
make_env,
make_logger,
make_odt_loss,
@@ -31,28 +34,41 @@
def main(cfg: "DictConfig"): # noqa: F821
model_device = cfg.optim.device

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create logger
logger = make_logger(cfg)

# Create offline replay buffer
offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
cfg.replay_buffer, cfg.env.reward_scaling
)

# Create test environment
test_env = make_env(cfg.env, obs_loc, obs_std)

# Create policy model
actor = make_odt_model(cfg)
policy = actor.to(model_device)

# Create loss
loss_module = make_odt_loss(cfg.loss, policy)

# Create optimizer
transformer_optim, temperature_optim, scheduler = make_odt_optimizer(
cfg.optim, loss_module
)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

r0 = None
l0 = None
pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.logger.eval_steps
@@ -61,10 +77,12 @@ def main(cfg: "DictConfig"): # noqa: F821

print(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(i)
# Sample data
data = offline_buffer.sample()
# loss
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]
@@ -80,7 +98,16 @@ def main(cfg: "DictConfig"): # noqa: F821

scheduler.step()

# evaluation
# Log metrics
to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
"train/loss_entropy": loss_vals["loss_entropy"].item(),
"train/loss_alpha": loss_vals["loss_alpha"].item(),
"train/alpha": loss_vals["alpha"].item(),
"train/entropy": loss_vals["entropy"].item(),
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
inference_policy.eval()
if i % pretrain_log_interval == 0:
@@ -91,20 +118,15 @@ def main(cfg: "DictConfig"): # noqa: F821
break_when_any_done=False,
)
inference_policy.train()
if r0 is None:
r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
if l0 is None:
l0 = transformer_loss.item()
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
if logger is not None:
for key, value in loss_vals.items():
logger.log_scalar(key, value.item(), i)
logger.log_scalar("evaluation reward", eval_reward, i)
log_metrics(logger, to_log, i)

pbar.set_description(
f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
)
pbar.close()
print(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
13 changes: 10 additions & 3 deletions examples/decision_transformer/utils.py
@@ -19,7 +19,6 @@
DoubleToFloat,
EnvCreator,
ExcludeTransform,
NoopResetEnv,
ObservationNorm,
RandomCropTensorDict,
Reward2GoTransform,
@@ -65,8 +64,6 @@ def make_base_env(env_cfg):
env_task = env_cfg.task
env_kwargs.update({"task_name": env_task})
env = env_library(**env_kwargs)
if env_cfg.noop > 1:
env = TransformedEnv(env, NoopResetEnv(env_cfg.noop))
return env


@@ -472,3 +469,13 @@ def make_logger(cfg):
wandb_kwargs={"config": cfg},
)
return logger


# ====================================================================
# General utils
# ---------


def log_metrics(logger, metrics, step):
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step)
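A short usage sketch of the new log_metrics helper (metric names and values below are hypothetical): it forwards each key/value pair to the logger's log_scalar at the given step, which is what lets the training scripts above collect everything into a single to_log dictionary before logging.

# Hypothetical usage; `logger` is whatever make_logger(cfg) returns and may be None,
# in which case the training scripts skip the call.
to_log = {"train/loss": 0.42, "eval/reward": 3150.0}
if logger is not None:
    log_metrics(logger, to_log, step=1000)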