From 3e7319914bd7ed3cc834e1cc55685e68b2d8e9c6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 13 Mar 2023 18:16:20 +0000 Subject: [PATCH 01/12] init --- examples/ddpg/config.yaml | 100 +++++--- examples/ddpg/ddpg2.py | 378 +++++++++++++++++++++++++++++++ torchrl/trainers/helpers/envs.py | 26 +-- 3 files changed, 454 insertions(+), 50 deletions(-) create mode 100644 examples/ddpg/ddpg2.py diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 5ad3912c0ef..5efabe1302b 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -1,36 +1,64 @@ -env_name: HalfCheetah-v4 -env_task: "" -env_library: gym -async_collection: 1 -record_video: 0 -normalize_rewards_online: 1 -normalize_rewards_online_scale: 5 -frame_skip: 1 -frames_per_batch: 1024 -optim_steps_per_batch: 128 -batch_size: 256 -total_frames: 1000000 -prb: 1 -lr: 3e-4 -ou_exploration: 1 -multi_step: 1 -init_random_frames: 25000 -activation: elu -gSDE: 0 -from_pixels: 0 -#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1] -collector_devices: [cpu,cpu,cpu,cpu] -env_per_collector: 8 -num_workers: 32 -lr_scheduler: "" -value_network_update_interval: 200 -record_interval: 10 -max_frames_per_traj: -1 -weight_decay: 0.0 -annealing_frames: 1000000 -init_env_steps: 10000 -record_frames: 10000 -loss_function: smooth_l1 -batch_transform: 1 -buffer_prefetch: 64 -norm_stats: 1 +# task and env +env: + env_name: HalfCheetah-v4 + env_task: "" + env_library: gym + normalize_rewards_online: 1 + normalize_rewards_online_scale: 5 + frame_skip: 1 + norm_stats: 1 + num_envs: 4 + n_samples_stats: 1000 + noop: 1 + reward_scaling: + +# collector +collector: + async_collection: 1 + frames_per_batch: 1024 + total_frames: 1000000 + multi_step: 3 # 0 to disable + init_random_frames: 25000 + collector_devices: cpu # [cpu,cpu,cpu,cpu] + num_collectors: 4 + max_frames_per_traj: -1 + +# eval +recorder: + record_video: True + record_interval: 10 + record_frames: 10000 + +# logger +logger: + logger_class: wandb + exp_name: ddpg_cheetah_gym + +# Buffer +replay_buffer: + prb: 1 + buffer_prefetch: 64 + capacity: 1_000_000 + +# Optim +optim: + device: cpu + lr: 3e-4 + weight_decay: 0.0 + batch_size: 256 + lr_scheduler: "" + value_network_update_interval: 200 + optim_steps_per_batch: 8 + +# Policy and model +model: + ou_exploration: 1 + annealing_frames: 1000000 + noisy: False + activation: elu + +# loss +loss: + loss_function: smooth_l1 + gamma: 0.99 + tau: 0.05 diff --git a/examples/ddpg/ddpg2.py b/examples/ddpg/ddpg2.py new file mode 100644 index 00000000000..04fa8c640b6 --- /dev/null +++ b/examples/ddpg/ddpg2.py @@ -0,0 +1,378 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
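The configuration file is now grouped by concern (env, collector, recorder, logger, replay_buffer, optim, model, loss), with every entry reachable through a dotted path. A minimal sketch of how the grouped entries can be inspected and overridden outside of Hydra, using OmegaConf; this is illustrative only and not part of the patch, and the config path is assumed to be the default one in the repository:

    from omegaconf import OmegaConf

    # load the grouped config for a quick sanity check
    cfg = OmegaConf.load("examples/ddpg/config.yaml")
    assert cfg.collector.frames_per_batch == 1024
    assert cfg.optim.batch_size == 256
    # the same dotted paths serve as command-line overrides,
    # e.g. `optim.batch_size=10 collector.frames_per_batch=16`
    cfg.optim.batch_size = 10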
+ +import dataclasses +from copy import deepcopy + +import hydra +import torch.cuda +import tqdm +from hydra.core.config_store import ConfigStore +from tensordict.nn import TensorDictModule +from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector +from torchrl.data import ( + CompositeSpec, + LazyMemmapStorage, + MultiStep, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + NoopResetEnv, + ObservationNorm, + ParallelEnv, +) +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import ( + AdditiveGaussianWrapper, + DdpgMlpActor, + DdpgMlpQNet, + NoisyLinear, + OrnsteinUhlenbeckProcessWrapper, + ProbabilisticActor, + SafeModule, + TanhDelta, + ValueOperator, +) +from torchrl.objectives import DDPGLoss, SoftUpdate +from torchrl.record import VideoRecorder +from torchrl.record.loggers import generate_exp_name, get_logger, WandbLogger +from torchrl.trainers.helpers.collectors import ( + make_collector_offpolicy, + OffPolicyCollectorConfig, +) +from torchrl.trainers.helpers.envs import ( + correct_for_frame_skip, + EnvConfig, + initialize_observation_norm_transforms, + LIBS, + parallel_env_constructor, + retrieve_observation_norms_state_dict, + transformed_env_constructor, +) +from torchrl.trainers.helpers.logger import LoggerConfig +from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss +from torchrl.trainers.helpers.models import ( + ACTIVATIONS, + DDPGModelConfig, + make_ddpg_actor, +) +from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig +from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig + + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + + +def make_base_env(env_cfg, from_pixels=False): + env_library = LIBS[env_cfg.env_library] + env_name = env_cfg.env_name + frame_skip = env_cfg.frame_skip + + env_kwargs = { + "env_name": env_name, + "frame_skip": frame_skip, + "from_pixels": from_pixels, # for rendering + "pixels_only": False, + } + if env_library is DMControlEnv: + env_task = env_cfg.env_task + env_kwargs.update({"task_name": env_task}) + env = env_library(**env_kwargs) + if env_cfg.noop > 1: + env = TransformedEnv(env, NoopResetEnv(env_cfg.noop)) + return env + + +def make_transformed_env(base_env, env_cfg): + if not isinstance(env_cfg.reward_scaling, float): + env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) + + env_library = LIBS[env_cfg.env_library] + env = TransformedEnv(base_env) + + reward_scaling = env_cfg.reward_scaling + + env.append_transform(RewardScaling(0.0, reward_scaling)) + + double_to_float_list = [] + double_to_float_inv_list = [] + + # we concatenate all the state vectors + # even if there is a single tensor, it'll be renamed in "observation_vector" + selected_keys = [ + key for key in env.observation_spec.keys(True, True) if key != "pixels" + ] + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + obs_norm = ObservationNorm(in_keys=[out_key]) + env.append_transform(obs_norm) + + if env_library is DMControlEnv: + double_to_float_list += [ + "reward", + ] + double_to_float_list += [ + "action", + ] + 
double_to_float_inv_list += ["action"] # DMControl requires double-precision + double_to_float_list += ["observation_vector"] + else: + double_to_float_list += ["observation_vector"] + env.append_transform( + DoubleToFloat( + in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list + ) + ) + return env + + +def make_parallel_env(env_cfg, state_dict): + num_envs = env_cfg.num_envs + env = make_transformed_env( + ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg + ) + for t in env.transform: + if isinstance(t, ObservationNorm): + t.init_stats(3, cat_dim=1, reduce_dim=[0, 1]) + env.load_state_dict(state_dict) + return env + + +def make_collector(cfg, state_dict, policy): + env_cfg = cfg.env + loss_cfg = cfg.loss + collector_cfg = cfg.collector + if collector_cfg.async_collection: + collector_class = MultiaSyncDataCollector + else: + collector_class = MultiSyncDataCollector + if collector_cfg.multi_step: + ms = MultiStep(gamma=loss_cfg.gamma, n_steps=collector_cfg.multi_step) + else: + ms = None + collector = collector_class( + [make_parallel_env(env_cfg, state_dict=state_dict)] + * collector_cfg.num_collectors, + policy, + frames_per_batch=collector_cfg.frames_per_batch, + total_frames=collector_cfg.total_frames, + postproc=ms, + devices=collector_cfg.collector_devices, + init_random_frames=collector_cfg.init_random_frames, + max_frames_per_traj=collector_cfg.max_frames_per_traj, + ) + return collector + + +def make_logger(logger_cfg): + if logger_cfg.logger_class == "wandb": + logger = WandbLogger(logger_cfg.exp_name) + else: + raise NotImplementedError + return logger + + +def make_recorder(cfg, logger): + env_cfg = deepcopy(cfg.env) + env = make_transformed_env(make_base_env(env_cfg, from_pixels=True), env_cfg) + env.insert_transform( + 0, VideoRecorder(logger=logger, tag=cfg.logger.exp_name, in_keys=["pixels"]) + ) + + +def make_replay_buffer(rb_cfg): + if rb_cfg.prb: + sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5) + else: + sampler = RandomSampler() + return TensorDictReplayBuffer( + storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler + ) + + +def make_ddpg_model(cfg): + + env_cfg = cfg.env + model_cfg = cfg.model + proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + + noisy = model_cfg.noisy + + linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear + + env_specs = proof_environment.specs + out_features = env_specs["input_spec"]["action"].shape[0] + + actor_net_default_kwargs = { + "action_dim": out_features, + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[model_cfg.activation], + }, + } + in_keys = ["observation_vector"] + actor_net = DdpgMlpActor(**actor_net_default_kwargs) + actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=["param"]) + + # We use a ProbabilisticActor to make sure that we map the + # network output to the right space using a TanhDelta + # distribution. + actor = ProbabilisticActor( + module=actor_module, + in_keys=["param"], + spec=CompositeSpec(action=env_specs["input_spec"]["action"]), + safe=True, + distribution_class=TanhDelta, + distribution_kwargs={ + "min": env_specs["input_spec"]["action"].space.minimum, + "max": env_specs["input_spec"]["action"].space.maximum, + }, + ) + + # Value model: DdpgMlpQNet is a specialized class that reads the state and + # the action and outputs a value from it. 
It has two sub-components that + # we parameterize with `mlp_net_kwargs_net1` and `mlp_net_kwargs_net2`. + state_class = ValueOperator + value_net_default_kwargs1 = { + "activation_class": ACTIVATIONS[model_cfg.activation], + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[model_cfg.activation], + "bias_last_layer": True, + } + value_net_default_kwargs2 = { + "num_cells": [400, 300], + "activation_class": ACTIVATIONS[model_cfg.activation], + "bias_last_layer": True, + "layer_class": linear_layer_class, + } + in_keys = ["observation_vector", "action"] + out_keys = ["state_action_value"] + q_net = DdpgMlpQNet( + mlp_net_kwargs_net1=value_net_default_kwargs1, + mlp_net_kwargs_net2=value_net_default_kwargs2, + ) + value = state_class( + in_keys=in_keys, + out_keys=out_keys, + module=q_net, + ) + + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + for t in proof_environment.transform: + if isinstance(t, ObservationNorm): + t.init_stats(2) + td = proof_environment.rollout(max_steps=1000) + print(td) + actor(td) + value(td) + + return actor, value + + +def make_policy(model_cfg, actor): + if model_cfg.ou_exploration: + return OrnsteinUhlenbeckProcessWrapper(actor) + else: + return AdditiveGaussianWrapper(actor) + + +def get_stats(env_cfg): + env = make_transformed_env(make_base_env(env_cfg), env_cfg) + for t in env.transform: + if isinstance(t, ObservationNorm): + t.init_stats(env_cfg.n_samples_stats) + return env.state_dict() + + +def make_loss(loss_cfg, actor_network, value_network): + loss = DDPGLoss( + actor_network, + value_network, + gamma=loss_cfg.gamma, + loss_function=loss_cfg.loss_function, + ) + target_net_updater = SoftUpdate(loss, 1 - loss_cfg.tau) + target_net_updater.init_() + return loss, target_net_updater + + +def make_optim(optim_cfg, actor_network, value_network): + optim = torch.optim.Adam( + list(actor_network.parameters()) + list(value_network.parameters()), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + return optim + + +@hydra.main(config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + + cfg = correct_for_frame_skip(cfg) + model_device = cfg.optim.device + + exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) + + state_dict = get_stats(cfg.env) + logger = make_logger(cfg.logger) + recorder = make_recorder(cfg, logger) + replay_buffer = make_replay_buffer(cfg.replay_buffer) + + actor_network, value_network = make_ddpg_model(cfg) + actor_network = actor_network.to(model_device) + value_network = value_network.to(model_device) + + policy = make_policy(cfg.model, actor_network) + collector = make_collector(cfg, state_dict=state_dict, policy=policy) + loss, target_net_updater = make_loss(cfg.loss, actor_network, value_network) + optim = make_optim(cfg.optim, actor_network, value_network) + + optim_steps_per_batch = cfg.optim.optim_steps_per_batch + batch_size = cfg.optim.batch_size + init_random_frames = cfg.collector.init_random_frames + + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + collected_frames = 0 + for i, data in enumerate(collector): + collected_frames += data.numel() + pbar.update(data.numel()) + # extend replay buffer + replay_buffer.extend(data.view(-1)) + if collected_frames >= init_random_frames: + for j in range(optim_steps_per_batch): + # sample + sample = replay_buffer.sample(batch_size) + # loss + loss_vals = loss(sample) + # backprop + loss_val = sum( + val for key, val in loss_vals.items() if key.startswith("loss") + ) + loss_val.backward() + 
optim.step() + optim.zero_grad() + target_net_updater.step() + pbar.set_description(f"loss: {loss_val.item(): 4.4f}") + collector.update_policy_weights_() + + +if __name__ == "__main__": + main() diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index fac07ee0afd..a2842973bfa 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -55,20 +55,18 @@ def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig": # noqa: F821 """ # Adapt all frame counts wrt frame_skip - if cfg.frame_skip != 1: - fields = [ - "max_frames_per_traj", - "total_frames", - "frames_per_batch", - "record_frames", - "annealing_frames", - "init_random_frames", - "init_env_steps", - "noops", - ] - for field in fields: - if hasattr(cfg, field): - setattr(cfg, field, getattr(cfg, field) // cfg.frame_skip) + + frame_skip = cfg.env.frame_skip + + if frame_skip != 1: + cfg.collector.max_frames_per_traj //= frame_skip + cfg.collector.total_frames //= frame_skip + cfg.collector.frames_per_batch //= frame_skip + cfg.collector.init_random_frames //= frame_skip + cfg.collector.init_env_steps //= frame_skip + cfg.recorder.record_frames //= frame_skip + cfg.model.annealing_frames //= frame_skip + cfg.env.noops //= frame_skip return cfg From 671ad0a210b91c09a51da53154273fd4f0f23be8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 20 Mar 2023 11:41:27 +0000 Subject: [PATCH 02/12] amend --- .../linux_examples/scripts/run_test.sh | 87 +++-- examples/ddpg/config.yaml | 7 +- examples/ddpg/ddpg.py | 274 +++++-------- examples/ddpg/{ddpg2.py => utils.py} | 362 +++++++++++------- torchrl/trainers/trainers.py | 2 +- 5 files changed, 377 insertions(+), 355 deletions(-) rename examples/ddpg/{ddpg2.py => utils.py} (61%) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index cc57b730be8..4e932c12e78 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -27,19 +27,71 @@ export MKL_THREADING_LAYER=GNU python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 20 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20 +# ======================================================================================== +# DDPG +# ---- +# +# Modalities: +# ^^^^^^^^^^^ +# +# pixels on/off +# Batched on/off +# # With batched environments python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=4 \ - env_per_collector=2 \ - collector_devices=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=2 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + record.video=True \ + record.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=True +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=2 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + record.video=True \ + 
record.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=False +# With single envs +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=1 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + record.video=True \ + record.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=True +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=1 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + record.video=True \ + record.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=False + python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \ total_frames=48 \ batch_size=10 \ @@ -112,19 +164,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea buffer_size=120 \ rssm_hidden_dim=17 -# With single envs -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=2 \ - env_per_collector=1 \ - collector_devices=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \ total_frames=48 \ batch_size=10 \ diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 5efabe1302b..351fae3e24d 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -11,6 +11,7 @@ env: n_samples_stats: 1000 noop: 1 reward_scaling: + from_pixels: False # collector collector: @@ -25,9 +26,9 @@ collector: # eval recorder: - record_video: True - record_interval: 10 - record_frames: 10000 + video: True + interval: 10000 # record interval in frames + frames: 10000 # logger logger: diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index aed849cd6b5..3e634bc479a 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -2,196 +2,104 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""DDPG Example. -import dataclasses +This is a self-contained example of a DDPG training script. + +It works across Gym and DM-control over a variety of tasks. + +Both state and pixel-based environments are supported. + +The helper functions are coded in the utils.py associated with this script. 
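+Training losses and rewards are written to the configured logger, and the policy
+is evaluated periodically with a recorder that can optionally log rendered videos.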
+ +""" import hydra -import torch.cuda -from hydra.core.config_store import ConfigStore -from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv -from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import OrnsteinUhlenbeckProcessWrapper -from torchrl.record import VideoRecorder -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.collectors import ( - make_collector_offpolicy, - OffPolicyCollectorConfig, +import tqdm +from torchrl.trainers.helpers.envs import correct_for_frame_skip + +from utils import ( + get_stats, + make_collector, + make_ddpg_model, + make_logger, + make_loss, + make_optim, + make_policy, + make_recorder, + make_replay_buffer, ) -from torchrl.trainers.helpers.envs import ( - correct_for_frame_skip, - EnvConfig, - initialize_observation_norm_transforms, - parallel_env_constructor, - retrieve_observation_norms_state_dict, - transformed_env_constructor, -) -from torchrl.trainers.helpers.logger import LoggerConfig -from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss -from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor -from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig -from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig - -config_fields = [ - (config_field.name, config_field.type, config_field) - for config_cls in ( - TrainerConfig, - OffPolicyCollectorConfig, - EnvConfig, - LossConfig, - DDPGModelConfig, - LoggerConfig, - ReplayArgsConfig, - ) - for config_field in dataclasses.fields(config_cls) -] -Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) -cs = ConfigStore.instance() -cs.store(name="config", node=Config) - -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - - -@hydra.main(version_base=None, config_path=".", config_name="config") + + +@hydra.main(config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 cfg = correct_for_frame_skip(cfg) - - if not isinstance(cfg.reward_scaling, float): - cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0) - - device = ( - torch.device("cpu") - if torch.cuda.device_count() == 0 - else torch.device("cuda:0") - ) - - exp_name = generate_exp_name("DDPG", cfg.exp_name) - logger = get_logger( - logger_type=cfg.logger, logger_name="ddpg_logging", experiment_name=exp_name - ) - video_tag = exp_name if cfg.record_video else "" - - key, init_env_steps, stats = None, None, None - if not cfg.vecnorm and cfg.norm_stats: - if not hasattr(cfg, "init_env_steps"): - raise AttributeError("init_env_steps missing from arguments.") - key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") - init_env_steps = cfg.init_env_steps - stats = {"loc": None, "scale": None} - elif cfg.from_pixels: - stats = {"loc": 0.5, "scale": 0.5} - - proof_env = transformed_env_constructor( - cfg=cfg, - stats=stats, - use_env_creator=False, - )() - initialize_observation_norm_transforms( - proof_environment=proof_env, num_iter=init_env_steps, key=key - ) - _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] - - model = make_ddpg_actor( - proof_env, - cfg=cfg, - device=device, - ) - loss_module, target_net_updater = make_ddpg_loss(model, cfg) - - actor_model_explore = model[0] - if cfg.ou_exploration: - if cfg.gSDE: - raise 
RuntimeError("gSDE and ou_exploration are incompatible") - actor_model_explore = OrnsteinUhlenbeckProcessWrapper( - actor_model_explore, - annealing_num_steps=cfg.annealing_frames, - sigma=cfg.ou_sigma, - theta=cfg.ou_theta, - ).to(device) - if device == torch.device("cpu"): - # mostly for debugging - actor_model_explore.share_memory() - - if cfg.gSDE: - with torch.no_grad(), set_exploration_mode("random"): - # get dimensions to build the parallel env - proof_td = actor_model_explore(proof_env.reset().to(device)) - action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:] - del proof_td - else: - action_dim_gsde, state_dim_gsde = None, None - - proof_env.close() - - create_env_fn = parallel_env_constructor( - cfg=cfg, - obs_norm_state_dict=obs_norm_state_dict, - action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde, - ) - - collector = make_collector_offpolicy( - make_env=create_env_fn, - actor_model_explore=actor_model_explore, - cfg=cfg, - # make_env_kwargs=[ - # {"device": device} if device >= 0 else {} - # for device in args.env_rendering_devices - # ], - ) - - replay_buffer = make_replay_buffer(device, cfg) - - recorder = transformed_env_constructor( - cfg, - video_tag=video_tag, - norm_obs_only=True, - obs_norm_state_dict=obs_norm_state_dict, - logger=logger, - use_env_creator=False, - )() - if isinstance(create_env_fn, ParallelEnv): - raise NotImplementedError("This behaviour is deprecated") - elif isinstance(create_env_fn, EnvCreator): - recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict()) - elif isinstance(create_env_fn, TransformedEnv): - recorder.transform = create_env_fn.transform.clone() - else: - raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}") - if logger is not None and video_tag: - recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag)) - - # reset reward scaling - for t in recorder.transform: - if isinstance(t, RewardScaling): - t.scale.fill_(1.0) - t.loc.fill_(0.0) - - trainer = make_trainer( - collector, - loss_module, - recorder, - target_net_updater, - actor_model_explore, - replay_buffer, - logger, - cfg, - ) - - final_seed = collector.set_seed(cfg.seed) - print(f"init seed: {cfg.seed}, final seed: {final_seed}") - - trainer.train() - return (logger.log_dir, trainer._log_dict) + model_device = cfg.optim.device + + state_dict = get_stats(cfg.env) + logger = make_logger(cfg.logger) + replay_buffer = make_replay_buffer(cfg.replay_buffer) + + actor_network, value_network = make_ddpg_model(cfg) + actor_network = actor_network.to(model_device) + value_network = value_network.to(model_device) + + policy = make_policy(cfg.model, actor_network) + collector = make_collector(cfg, state_dict=state_dict, policy=policy) + loss, target_net_updater = make_loss(cfg.loss, actor_network, value_network) + optim = make_optim(cfg.optim, actor_network, value_network) + recorder = make_recorder(cfg, logger, policy) + + optim_steps_per_batch = cfg.optim.optim_steps_per_batch + batch_size = cfg.optim.batch_size + init_random_frames = cfg.collector.init_random_frames + record_interval = cfg.recorder.interval + + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + collected_frames = 0 + + r0 = None + l0 = None + for data in collector: + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + # extend replay buffer + replay_buffer.extend(data.view(-1)) + if collected_frames >= init_random_frames: + for _ in range(optim_steps_per_batch): + # sample + sample = 
replay_buffer.sample(batch_size) + # loss + loss_vals = loss(sample) + # backprop + loss_val = sum( + val for key, val in loss_vals.items() if key.startswith("loss") + ) + loss_val.backward() + optim.step() + optim.zero_grad() + target_net_updater.step() + if r0 is None: + r0 = data["reward"].mean().item() + if l0 is None: + l0 = loss_val.item() + + for key, value in loss_vals.item(): + logger.log_scalar(key, value.item(), collected_frames) + logger.log_scalar( + "reward_training", data["reward"].mean().item(), collected_frames + ) + + pbar.set_description( + f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['reward'].mean(): 4.4f} (init={r0: 4.4f})" + ) + collector.update_policy_weights_() + if ( + collected_frames - frames_in_batch + ) // record_interval < collected_frames // record_interval: + recorder() if __name__ == "__main__": diff --git a/examples/ddpg/ddpg2.py b/examples/ddpg/utils.py similarity index 61% rename from examples/ddpg/ddpg2.py rename to examples/ddpg/utils.py index 04fa8c640b6..ed1d0a56621 100644 --- a/examples/ddpg/ddpg2.py +++ b/examples/ddpg/utils.py @@ -1,16 +1,9 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import dataclasses from copy import deepcopy -import hydra -import torch.cuda -import tqdm -from hydra.core.config_store import ConfigStore +import torch.nn +import torch.optim from tensordict.nn import TensorDictModule + from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector from torchrl.data import ( CompositeSpec, @@ -20,52 +13,39 @@ ) from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.envs import ( + CatFrames, CatTensors, DoubleToFloat, EnvCreator, + GrayScale, NoopResetEnv, ObservationNorm, ParallelEnv, + Resize, + RewardScaling, + ToTensorImage, + TransformedEnv, ) from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import ( AdditiveGaussianWrapper, + DdpgCnnActor, + DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, NoisyLinear, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, - SafeModule, TanhDelta, ValueOperator, ) from torchrl.objectives import DDPGLoss, SoftUpdate from torchrl.record import VideoRecorder -from torchrl.record.loggers import generate_exp_name, get_logger, WandbLogger -from torchrl.trainers.helpers.collectors import ( - make_collector_offpolicy, - OffPolicyCollectorConfig, -) -from torchrl.trainers.helpers.envs import ( - correct_for_frame_skip, - EnvConfig, - initialize_observation_norm_transforms, - LIBS, - parallel_env_constructor, - retrieve_observation_norms_state_dict, - transformed_env_constructor, -) -from torchrl.trainers.helpers.logger import LoggerConfig -from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss -from torchrl.trainers.helpers.models import ( - ACTIVATIONS, - DDPGModelConfig, - make_ddpg_actor, -) -from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig -from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig +from torchrl.record.loggers import generate_exp_name, WandbLogger +from torchrl.trainers import Recorder +from torchrl.trainers.helpers.envs import LIBS +from torchrl.trainers.helpers.models import ACTIVATIONS DEFAULT_REWARD_SCALING = { @@ -78,11 +58,17 @@ "humanoid": 100, } +# 
==================================================================== +# Environment utils +# ----------------- -def make_base_env(env_cfg, from_pixels=False): + +def make_base_env(env_cfg, from_pixels=None): env_library = LIBS[env_cfg.env_library] env_name = env_cfg.env_name frame_skip = env_cfg.frame_skip + if from_pixels is None: + from_pixels = env_cfg.from_pixels env_kwargs = { "env_name": env_name, @@ -100,6 +86,56 @@ def make_base_env(env_cfg, from_pixels=False): def make_transformed_env(base_env, env_cfg): + from_pixels = env_cfg.from_pixels + if from_pixels: + return make_transformed_env_pixels(base_env, env_cfg) + else: + return make_transformed_env_states(base_env, env_cfg) + + +def make_transformed_env_pixels(base_env, env_cfg): + if not isinstance(env_cfg.reward_scaling, float): + env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) + + env_library = LIBS[env_cfg.env_library] + env = TransformedEnv(base_env) + + reward_scaling = env_cfg.reward_scaling + + env.append_transform(RewardScaling(0.0, reward_scaling)) + + double_to_float_list = [] + double_to_float_inv_list = [] + + # + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + + obs_norm = ObservationNorm(in_keys=["pixels"]) + env.append_transform(obs_norm) + + if env_library is DMControlEnv: + double_to_float_list += [ + "reward", + ] + double_to_float_list += [ + "action", + ] + double_to_float_inv_list += ["action"] # DMControl requires double-precision + double_to_float_list += ["observation_vector"] + else: + double_to_float_list += ["observation_vector"] + env.append_transform( + DoubleToFloat( + in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list + ) + ) + return env + + +def make_transformed_env_states(base_env, env_cfg): if not isinstance(env_cfg.reward_scaling, float): env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) @@ -155,6 +191,29 @@ def make_parallel_env(env_cfg, state_dict): return env +def get_stats(env_cfg): + from_pixels = env_cfg.from_pixels + env = make_transformed_env(make_base_env(env_cfg), env_cfg) + for t in env.transform: + if isinstance(t, ObservationNorm): + if from_pixels: + t.init_stats( + env_cfg.n_samples_stats, + cat_dim=-3, + reduce_dim=(-1, -2, -3), + keep_dims=(-1, -2, -3), + ) + else: + t.init_stats(env_cfg.n_samples_stats) + + return env.state_dict() + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + def make_collector(cfg, state_dict, policy): env_cfg = cfg.env loss_cfg = cfg.loss @@ -181,22 +240,6 @@ def make_collector(cfg, state_dict, policy): return collector -def make_logger(logger_cfg): - if logger_cfg.logger_class == "wandb": - logger = WandbLogger(logger_cfg.exp_name) - else: - raise NotImplementedError - return logger - - -def make_recorder(cfg, logger): - env_cfg = deepcopy(cfg.env) - env = make_transformed_env(make_base_env(env_cfg, from_pixels=True), env_cfg) - env.insert_transform( - 0, VideoRecorder(logger=logger, tag=cfg.logger.exp_name, in_keys=["pixels"]) - ) - - def make_replay_buffer(rb_cfg): if rb_cfg.prb: sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5) @@ -207,29 +250,34 @@ def make_replay_buffer(rb_cfg): ) +# ==================================================================== +# Model +# ----- +# +# We give one version of the model for learning from pixels, and one for 
state. +# TorchRL comes in handy at this point, as the high-level interactions with +# these models is unchanged, regardless of the modality. +# + + def make_ddpg_model(cfg): env_cfg = cfg.env model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - - noisy = model_cfg.noisy - - linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear - env_specs = proof_environment.specs - out_features = env_specs["input_spec"]["action"].shape[0] + from_pixels = env_cfg.from_pixels - actor_net_default_kwargs = { - "action_dim": out_features, - "mlp_net_kwargs": { - "layer_class": linear_layer_class, - "activation_class": ACTIVATIONS[model_cfg.activation], - }, - } - in_keys = ["observation_vector"] - actor_net = DdpgMlpActor(**actor_net_default_kwargs) - actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=["param"]) + if not from_pixels: + actor_net, q_net = make_ddpg_modules_state(model_cfg, proof_environment) + in_keys = ["observation_vector"] + out_keys = ["param"] + else: + actor_net, q_net = make_ddpg_modules_pixels(model_cfg, proof_environment) + in_keys = ["pixels"] + out_keys = ["param", "hidden"] + + actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) # We use a ProbabilisticActor to make sure that we map the # network output to the right space using a TanhDelta @@ -246,12 +294,53 @@ def make_ddpg_model(cfg): }, ) + if not from_pixels: + in_keys = ["observation_vector", "action"] + else: + in_keys = ["pixels", "action"] + + out_keys = ["state_action_value"] + value = ValueOperator( + in_keys=in_keys, + out_keys=out_keys, + module=q_net, + ) + + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + for t in proof_environment.transform: + if isinstance(t, ObservationNorm): + t.init_stats(2) + td = proof_environment.rollout(max_steps=1000) + print(td) + actor(td) + value(td) + + return actor, value + + +def make_ddpg_modules_state(model_cfg, proof_environment): + + noisy = model_cfg.noisy + + linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear + + env_specs = proof_environment.specs + out_features = env_specs["input_spec"]["action"].shape[0] + + actor_net_default_kwargs = { + "action_dim": out_features, + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[model_cfg.activation], + }, + } + actor_net = DdpgMlpActor(**actor_net_default_kwargs) + # Value model: DdpgMlpQNet is a specialized class that reads the state and # the action and outputs a value from it. It has two sub-components that # we parameterize with `mlp_net_kwargs_net1` and `mlp_net_kwargs_net2`. 
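The DdpgMlpQNet layout referenced in the comment above processes the observation with a first MLP and then concatenates the action before a second MLP produces the scalar value. A rough plain-torch sketch of that two-stage structure; illustrative only, the class and variable names are assumptions and the layer sizes are arbitrary:

    import torch
    from torch import nn

    class TwoStageQNet(nn.Module):
        # first MLP embeds the observation, second MLP reads [embedding, action]
        def __init__(self, obs_dim: int, action_dim: int, hidden: int = 400):
            super().__init__()
            self.net1 = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ELU())
            self.net2 = nn.Sequential(
                nn.Linear(hidden + action_dim, 300), nn.ELU(), nn.Linear(300, 1)
            )

        def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
            hidden = self.net1(obs)
            return self.net2(torch.cat([hidden, action], dim=-1))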
- state_class = ValueOperator value_net_default_kwargs1 = { - "activation_class": ACTIVATIONS[model_cfg.activation], "layer_class": linear_layer_class, "activation_class": ACTIVATIONS[model_cfg.activation], "bias_last_layer": True, @@ -262,29 +351,40 @@ def make_ddpg_model(cfg): "bias_last_layer": True, "layer_class": linear_layer_class, } - in_keys = ["observation_vector", "action"] - out_keys = ["state_action_value"] q_net = DdpgMlpQNet( mlp_net_kwargs_net1=value_net_default_kwargs1, mlp_net_kwargs_net2=value_net_default_kwargs2, ) - value = state_class( - in_keys=in_keys, - out_keys=out_keys, - module=q_net, - ) + return actor_net, q_net - # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): - for t in proof_environment.transform: - if isinstance(t, ObservationNorm): - t.init_stats(2) - td = proof_environment.rollout(max_steps=1000) - print(td) - actor(td) - value(td) - return actor, value +def make_ddpg_modules_pixels(model_cfg, proof_environment): + noisy = model_cfg.noisy + + linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear + + env_specs = proof_environment.specs + out_features = env_specs["input_spec"]["action"].shape[0] + + actor_net_default_kwargs = { + "action_dim": out_features, + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[model_cfg.activation], + }, + "conv_net_kwargs": {"activation_class": ACTIVATIONS[model_cfg.activation]}, + } + actor_net = DdpgCnnActor(**actor_net_default_kwargs) + + value_net_default_kwargs = { + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + } + q_net = DdpgCnnQNet(**value_net_default_kwargs) + + return actor_net, q_net def make_policy(model_cfg, actor): @@ -294,12 +394,9 @@ def make_policy(model_cfg, actor): return AdditiveGaussianWrapper(actor) -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - for t in env.transform: - if isinstance(t, ObservationNorm): - t.init_stats(env_cfg.n_samples_stats) - return env.state_dict() +# ==================================================================== +# DDPG Loss +# --------- def make_loss(loss_cfg, actor_network, value_network): @@ -323,56 +420,33 @@ def make_optim(optim_cfg, actor_network, value_network): return optim -@hydra.main(config_path=".", config_name="config") -def main(cfg: "DictConfig"): # noqa: F821 - - cfg = correct_for_frame_skip(cfg) - model_device = cfg.optim.device - - exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) - - state_dict = get_stats(cfg.env) - logger = make_logger(cfg.logger) - recorder = make_recorder(cfg, logger) - replay_buffer = make_replay_buffer(cfg.replay_buffer) - - actor_network, value_network = make_ddpg_model(cfg) - actor_network = actor_network.to(model_device) - value_network = value_network.to(model_device) - - policy = make_policy(cfg.model, actor_network) - collector = make_collector(cfg, state_dict=state_dict, policy=policy) - loss, target_net_updater = make_loss(cfg.loss, actor_network, value_network) - optim = make_optim(cfg.optim, actor_network, value_network) - - optim_steps_per_batch = cfg.optim.optim_steps_per_batch - batch_size = cfg.optim.batch_size - init_random_frames = cfg.collector.init_random_frames - - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - collected_frames = 0 - for i, data in enumerate(collector): - collected_frames += data.numel() - pbar.update(data.numel()) - # extend replay buffer - replay_buffer.extend(data.view(-1)) - if 
collected_frames >= init_random_frames: - for j in range(optim_steps_per_batch): - # sample - sample = replay_buffer.sample(batch_size) - # loss - loss_vals = loss(sample) - # backprop - loss_val = sum( - val for key, val in loss_vals.items() if key.startswith("loss") - ) - loss_val.backward() - optim.step() - optim.zero_grad() - target_net_updater.step() - pbar.set_description(f"loss: {loss_val.item(): 4.4f}") - collector.update_policy_weights_() +# ==================================================================== +# Logging and recording +# --------------------- -if __name__ == "__main__": - main() +def make_logger(logger_cfg): + exp_name = generate_exp_name("DDPG", logger_cfg.exp_name) + logger_cfg.exp_name = exp_name + if logger_cfg.logger_class == "wandb": + logger = WandbLogger(exp_name) + else: + raise NotImplementedError + return logger + + +def make_recorder(cfg, logger, policy) -> Recorder: + env_cfg = deepcopy(cfg.env) + env = make_transformed_env(make_base_env(env_cfg, from_pixels=True), env_cfg) + if cfg.recorder.video: + env.insert_transform( + 0, VideoRecorder(logger=logger, tag=cfg.logger.exp_name, in_keys=["pixels"]) + ) + return Recorder( + record_interval=1, + record_frames=cfg.recorder.frames, + frame_skip=env_cfg.frame_skip, + policy_exploration=policy, + recorder=env, + exploration_mode="mean", + ) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 1e3551c3b01..aee521c90e8 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1161,7 +1161,7 @@ def __init__( self.log_pbar = log_pbar @torch.inference_mode() - def __call__(self, batch: TensorDictBase) -> Dict: + def __call__(self, batch: TensorDictBase = None) -> Dict: out = None if self._count % self.record_interval == 0: with set_exploration_mode(self.exploration_mode): From fdbb1a5a8253b9be7b36b54f84e374f5fcee160b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 20 Mar 2023 18:12:19 +0000 Subject: [PATCH 03/12] amend --- .../unittest/linux_examples/scripts/run_test.sh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 4e932c12e78..669cc03c5d6 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -47,8 +47,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py env.num_envs=2 \ optim.batch_size=10 \ optim.optim_steps_per_batch=1 \ - record.video=True \ - record.frames=4 \ + recorder.video=True \ + recorder.frames=4 \ replay_buffer.capacity=120 \ env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ @@ -60,8 +60,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py env.num_envs=2 \ optim.batch_size=10 \ optim.optim_steps_per_batch=1 \ - record.video=True \ - record.frames=4 \ + recorder.video=True \ + recorder.frames=4 \ replay_buffer.capacity=120 \ env.from_pixels=False # With single envs @@ -74,8 +74,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py env.num_envs=1 \ optim.batch_size=10 \ optim.optim_steps_per_batch=1 \ - record.video=True \ - record.frames=4 \ + recorder.video=True \ + recorder.frames=4 \ replay_buffer.capacity=120 \ env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ @@ -87,8 +87,8 @@ python 
.circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py env.num_envs=1 \ optim.batch_size=10 \ optim.optim_steps_per_batch=1 \ - record.video=True \ - record.frames=4 \ + recorder.video=True \ + recorder.frames=4 \ replay_buffer.capacity=120 \ env.from_pixels=False From 991ebda889c39fc33b1b1933584990fdb5258682 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 20 Mar 2023 21:16:53 +0000 Subject: [PATCH 04/12] amend --- .../linux_examples/scripts/run_test.sh | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 669cc03c5d6..0a110f10545 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -38,19 +38,19 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ # Batched on/off # # With batched environments -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=16 \ - collector.num_collectors=4 \ - collector.collector_devices=cuda:0 \ - env.num_envs=2 \ - optim.batch_size=10 \ - optim.optim_steps_per_batch=1 \ - recorder.video=True \ - recorder.frames=4 \ - replay_buffer.capacity=120 \ - env.from_pixels=True +#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ +# collector.total_frames=48 \ +# collector.init_random_frames=10 \ +# collector.frames_per_batch=16 \ +# collector.num_collectors=4 \ +# collector.collector_devices=cuda:0 \ +# env.num_envs=2 \ +# optim.batch_size=10 \ +# optim.optim_steps_per_batch=1 \ +# recorder.video=True \ +# recorder.frames=4 \ +# replay_buffer.capacity=120 \ +# env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -65,19 +65,19 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py replay_buffer.capacity=120 \ env.from_pixels=False # With single envs -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=16 \ - collector.num_collectors=4 \ - collector.collector_devices=cuda:0 \ - env.num_envs=1 \ - optim.batch_size=10 \ - optim.optim_steps_per_batch=1 \ - recorder.video=True \ - recorder.frames=4 \ - replay_buffer.capacity=120 \ - env.from_pixels=True +#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ +# collector.total_frames=48 \ +# collector.init_random_frames=10 \ +# collector.frames_per_batch=16 \ +# collector.num_collectors=4 \ +# collector.collector_devices=cuda:0 \ +# env.num_envs=1 \ +# optim.batch_size=10 \ +# optim.optim_steps_per_batch=1 \ +# recorder.video=True \ +# recorder.frames=4 \ +# replay_buffer.capacity=120 \ +# env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ From f028289ffdf383b665513f28182bcf9d8274be21 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 08:00:35 +0000 Subject: [PATCH 05/12] get_loggers --- .../linux_examples/scripts/run_test.sh | 60 ++++++++++--------- examples/ddpg/config.yaml | 2 +- examples/ddpg/utils.py | 9 ++- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git 
a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 0a110f10545..3dd3005bcf2 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -38,19 +38,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ # Batched on/off # # With batched environments -#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ -# collector.total_frames=48 \ -# collector.init_random_frames=10 \ -# collector.frames_per_batch=16 \ -# collector.num_collectors=4 \ -# collector.collector_devices=cuda:0 \ -# env.num_envs=2 \ -# optim.batch_size=10 \ -# optim.optim_steps_per_batch=1 \ -# recorder.video=True \ -# recorder.frames=4 \ -# replay_buffer.capacity=120 \ -# env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -63,21 +50,23 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py recorder.video=True \ recorder.frames=4 \ replay_buffer.capacity=120 \ - env.from_pixels=False + env.from_pixels=False \ + logger.backend=csv +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=2 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + recorder.video=True \ + recorder.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=True \ + logger.backend=csv # With single envs -#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ -# collector.total_frames=48 \ -# collector.init_random_frames=10 \ -# collector.frames_per_batch=16 \ -# collector.num_collectors=4 \ -# collector.collector_devices=cuda:0 \ -# env.num_envs=1 \ -# optim.batch_size=10 \ -# optim.optim_steps_per_batch=1 \ -# recorder.video=True \ -# recorder.frames=4 \ -# replay_buffer.capacity=120 \ -# env.from_pixels=True python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -90,7 +79,22 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py recorder.video=True \ recorder.frames=4 \ replay_buffer.capacity=120 \ - env.from_pixels=False + env.from_pixels=False \ + logger.backend=csv +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.num_collectors=4 \ + collector.collector_devices=cuda:0 \ + env.num_envs=1 \ + optim.batch_size=10 \ + optim.optim_steps_per_batch=1 \ + recorder.video=True \ + recorder.frames=4 \ + replay_buffer.capacity=120 \ + env.from_pixels=True \ + logger.backend=csv python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \ total_frames=48 \ diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 351fae3e24d..b4091a2b70d 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -32,7 +32,7 @@ recorder: # logger logger: - logger_class: wandb + backend: wandb exp_name: ddpg_cheetah_gym # Buffer diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index ed1d0a56621..2569a16ffc8 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ 
-42,7 +42,7 @@ ) from torchrl.objectives import DDPGLoss, SoftUpdate from torchrl.record import VideoRecorder -from torchrl.record.loggers import generate_exp_name, WandbLogger +from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers import Recorder from torchrl.trainers.helpers.envs import LIBS from torchrl.trainers.helpers.models import ACTIVATIONS @@ -428,10 +428,9 @@ def make_optim(optim_cfg, actor_network, value_network): def make_logger(logger_cfg): exp_name = generate_exp_name("DDPG", logger_cfg.exp_name) logger_cfg.exp_name = exp_name - if logger_cfg.logger_class == "wandb": - logger = WandbLogger(exp_name) - else: - raise NotImplementedError + logger = get_logger( + logger_cfg.backend, logger_name="ddpg", experiment_name=exp_name + ) return logger From 5dce0a52a03733de2116c6db743fb746ad691485 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 08:50:45 +0000 Subject: [PATCH 06/12] get_loggers --- examples/ddpg/ddpg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 3e634bc479a..d68b984aa57 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -86,7 +86,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if l0 is None: l0 = loss_val.item() - for key, value in loss_vals.item(): + for key, value in loss_vals.items(): logger.log_scalar(key, value.item(), collected_frames) logger.log_scalar( "reward_training", data["reward"].mean().item(), collected_frames From 84a2ca1155d5eca5d83ce5e7ed8a10eabf02efc2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 09:14:06 +0000 Subject: [PATCH 07/12] yum install mesa-libGLU --- .circleci/unittest/linux_examples/scripts/run_test.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 3dd3005bcf2..6c1814120a9 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -8,6 +8,16 @@ set -e +yum makecache +yum install -y glfw +yum install -y glew +yum install -y mesa-libGL +#yum install -y mesa-libGL-devel +yum install -y mesa-libOSMesa-devel +yum install mesa-libGLU -y +#yum -y install egl-utils +#yum -y install freeglut + eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env From abe08240b1c3ab5caacba9c2a3a8c8ada7f7d840 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 10:09:33 +0000 Subject: [PATCH 08/12] apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 --- .circleci/unittest/linux_examples/scripts/run_test.sh | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 6c1814120a9..b00a54d86a9 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -8,15 +8,7 @@ set -e -yum makecache -yum install -y glfw -yum install -y glew -yum install -y mesa-libGL -#yum install -y mesa-libGL-devel -yum install -y mesa-libOSMesa-devel -yum install mesa-libGLU -y -#yum -y install egl-utils -#yum -y install freeglut +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env From 
96f8a2aad2a1db180935e899be5f448d0ce174a6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 11:31:19 +0000 Subject: [PATCH 09/12] change examples image --- .circleci/config.yml | 2 +- .../linux_examples/scripts/run_test.sh | 2 +- .../linux_examples/scripts/setup_env.sh | 42 ++----------------- 3 files changed, 6 insertions(+), 40 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5e878991b8e..7246e3a945c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -375,7 +375,7 @@ jobs: image: ubuntu-2004-cuda-11.4:202110-01 resource_class: gpu.nvidia.medium environment: - image_name: "pytorch/manylinux-cuda117" + image_name: "nvidia/cudagl:11.4.0-base" TAR_OPTIONS: --no-same-owner PYTHON_VERSION: << parameters.python_version >> CU_VERSION: << parameters.cu_version >> diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index b00a54d86a9..cfb32bc80bd 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -8,7 +8,7 @@ set -e -apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 wget freeglut3 freeglut3-dev eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh index c79f25a6979..007d57572a4 100755 --- a/.circleci/unittest/linux_examples/scripts/setup_env.sh +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -9,6 +9,8 @@ set -e this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ + git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" conda_dir="${root_dir}/conda" @@ -71,18 +73,6 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ MUJOCO_GL=$PRIVATE_MUJOCO_GL \ PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL -# Software rendering requires GLX and OSMesa. 
-if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then - yum makecache - yum install -y glfw - yum install -y glew - yum install -y mesa-libGL - yum install -y mesa-libGL-devel - yum install -y mesa-libOSMesa-devel - yum -y install egl-utils - yum -y install freeglut -fi - pip install pip --upgrade conda env update --file "${this_dir}/environment.yml" --prune @@ -90,29 +80,5 @@ conda env update --file "${this_dir}/environment.yml" --prune conda deactivate conda activate "${env_dir}" -if [[ $OSTYPE != 'darwin'* ]]; then - # install ale-py: manylinux names are broken for CentOS so we need to manually download and - # rename them - PY_VERSION=$(python --version) - if [[ $PY_VERSION == *"3.7"* ]]; then - wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.8"* ]]; then - wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.9"* ]]; then - wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - elif [[ $PY_VERSION == *"3.10"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - fi - pip install "gymnasium[atari,accept-rom-license]" -else - pip install "gymnasium[atari,accept-rom-license]" -fi +pip install ale-py +pip install "gymnasium[atari,accept-rom-license]" From b261690ea2edcfc3cc76d9463929bea63b3eaa18 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 11:59:26 +0000 Subject: [PATCH 10/12] amend --- .circleci/unittest/linux_examples/scripts/install.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.circleci/unittest/linux_examples/scripts/install.sh b/.circleci/unittest/linux_examples/scripts/install.sh index a7c9bb93976..3205f7496ee 100755 --- a/.circleci/unittest/linux_examples/scripts/install.sh +++ b/.circleci/unittest/linux_examples/scripts/install.sh @@ -4,6 +4,8 @@ unset PYTORCH_VERSION # For unittest, nightly PyTorch is used as the following section, # so no need to set PYTORCH_VERSION. # In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev set -e From a73ead5112dbd337d013cce4bc83de83e9185acd Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 14:18:41 +0000 Subject: [PATCH 11/12] fix state dict --- examples/ddpg/ddpg.py | 2 +- examples/ddpg/utils.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index d68b984aa57..c2543cb8f82 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -41,7 +41,7 @@ def main(cfg: "DictConfig"): # noqa: F821 logger = make_logger(cfg.logger) replay_buffer = make_replay_buffer(cfg.replay_buffer) - actor_network, value_network = make_ddpg_model(cfg) + actor_network, value_network = make_ddpg_model(cfg, state_dict) actor_network = actor_network.to(model_device) value_network = value_network.to(model_device) diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 2569a16ffc8..85861859f32 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -260,11 +260,14 @@ def make_replay_buffer(rb_cfg): # -def make_ddpg_model(cfg): +def make_ddpg_model(cfg, state_dict): env_cfg = cfg.env model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + # we must initialize the observation norm transform + proof_environment.load_state_dict(state_dict) + env_specs = proof_environment.specs from_pixels = env_cfg.from_pixels From 3b249ecc665d51eedfb2f3f0a0e069bdae4b63fd Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Mar 2023 18:03:44 +0000 Subject: [PATCH 12/12] bf --- examples/ddpg/ddpg.py | 2 +- examples/ddpg/utils.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index c2543cb8f82..d68b984aa57 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -41,7 +41,7 @@ def main(cfg: "DictConfig"): # noqa: F821 logger = make_logger(cfg.logger) replay_buffer = make_replay_buffer(cfg.replay_buffer) - actor_network, value_network = make_ddpg_model(cfg, state_dict) + actor_network, value_network = make_ddpg_model(cfg) actor_network = actor_network.to(model_device) value_network = value_network.to(model_device) diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 85861859f32..ce7779a7211 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -194,19 +194,22 @@ def make_parallel_env(env_cfg, state_dict): def get_stats(env_cfg): from_pixels = env_cfg.from_pixels env = make_transformed_env(make_base_env(env_cfg), env_cfg) + init_stats(env, env_cfg.n_samples_stats, from_pixels) + return env.state_dict() + + +def init_stats(env, n_samples_stats, from_pixels): for t in env.transform: if isinstance(t, ObservationNorm): if from_pixels: t.init_stats( - env_cfg.n_samples_stats, + n_samples_stats, cat_dim=-3, reduce_dim=(-1, -2, -3), keep_dims=(-1, -2, -3), ) else: - t.init_stats(env_cfg.n_samples_stats) - - return env.state_dict() + t.init_stats(n_samples_stats) # ==================================================================== @@ -260,13 +263,13 @@ def make_replay_buffer(rb_cfg): # -def make_ddpg_model(cfg, state_dict): +def make_ddpg_model(cfg): env_cfg = cfg.env model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) # we must initialize the observation norm transform - proof_environment.load_state_dict(state_dict) + init_stats(proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels) 
env_specs = proof_environment.specs from_pixels = env_cfg.from_pixels
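For context on the statistics initialization that closes this series: in pixel mode the ObservationNorm transform ends up with a single location and scale kept broadcastable over the image dimensions, while in state mode the statistics are computed per feature over the batch. A rough plain-torch picture of the pixel case, illustrative only; the shapes and names are assumptions and not part of the patch:

    import torch

    # hypothetical batch of rollout pixels: [batch, stacked_frames, height, width]
    pixels = torch.rand(16, 4, 84, 84)
    # a single scalar statistic over every frame and pixel ...
    loc = pixels.mean()
    scale = pixels.std().clamp_min(1e-6)
    # ... kept with singleton dims so it broadcasts over [C, H, W],
    # mirroring keep_dims=(-1, -2, -3) in init_stats
    loc, scale = loc.view(1, 1, 1), scale.view(1, 1, 1)
    normalized = (pixels - loc) / scale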