@@ -11,17 +11,20 @@
 The helper functions are coded in the utils.py associated with this script.
 """

+import time
+
 import hydra

 import numpy as np
 import torch
 import torch.cuda
 import tqdm
-
+from tensordict import TensorDict
 from torchrl.envs.utils import ExplorationType, set_exploration_type

 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    log_metrics,
     make_collector,
     make_environment,
     make_loss_module,
@@ -35,6 +38,7 @@
 def main(cfg: "DictConfig"):  # noqa: F821
     device = torch.device(cfg.network.device)

+    # Create logger
     exp_name = generate_exp_name("SAC", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
@@ -48,132 +52,158 @@ def main(cfg: "DictConfig"):  # noqa: F821
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
     train_env, eval_env = make_environment(cfg)
-    # Create Agent
+
+    # Create agent
     model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)

-    # Create TD3 loss
+    # Create SAC loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
     collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
     replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
         device=device,
     )

-    # Make Optimizers
-    optimizer = make_sac_optimizer(cfg, loss_module)
-
-    rewards = []
-    rewards_eval = []
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)

     # Main loop
+    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
         cfg.collector.env_per_collector
         * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
     )
     prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
     eval_iter = cfg.logger.eval_iter
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_rollout_steps = cfg.env.max_episode_steps

+    sampling_start = time.time()
     for i, tensordict in enumerate(collector):
-        # update weights of the inference policy
+        sampling_time = time.time() - sampling_start
+
+        # Update weights of the inference policy
         collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
         pbar.update(tensordict.numel())

-        tensordict = tensordict.view(-1)
+        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        # Add to replay buffer
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (actor_losses, q_losses, alpha_losses) = ([], [], [])
-            for _ in range(num_updates):
-                # sample from replay buffer
+            losses = TensorDict(
+                {},
+                batch_size=[
+                    num_updates,
+                ],
+            )
+            for i in range(num_updates):
+                # Sample from replay buffer
                 sampled_tensordict = replay_buffer.sample().clone()

+                # Compute loss
                 loss_td = loss_module(sampled_tensordict)

                 actor_loss = loss_td["loss_actor"]
                 q_loss = loss_td["loss_qvalue"]
                 alpha_loss = loss_td["loss_alpha"]
-                loss = actor_loss + q_loss + alpha_loss

-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
+                # Update actor
+                optimizer_actor.zero_grad()
+                actor_loss.backward()
+                optimizer_actor.step()

-                q_losses.append(q_loss.item())
-                actor_losses.append(actor_loss.item())
-                alpha_losses.append(alpha_loss.item())
+                # Update critic
+                optimizer_critic.zero_grad()
+                q_loss.backward()
+                optimizer_critic.step()

-                # update qnet_target params
+                # Update alpha
+                optimizer_alpha.zero_grad()
+                alpha_loss.backward()
+                optimizer_alpha.step()
+
+                losses[i] = loss_td.select(
+                    "loss_actor", "loss_qvalue", "loss_alpha"
+                ).detach()
+
+                # Update qnet_target params
                 target_net_updater.step()

-                # update priority
+                # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
         )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                    "alpha_loss": np.mean(alpha_losses),
-                    "alpha": loss_td["alpha"],
-                    "entropy": loss_td["entropy"],
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
             )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
+            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
+            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
             with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-                if logger is not None:
-                    logger.log_scalar(
-                        "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                    )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

     collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
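
Note on the optimizer change: the single optimizer is replaced by three separate ones unpacked from make_sac_optimizer, which lives in utils.py and is not part of this diff. A rough sketch only, assuming the loss module is TorchRL's SACLoss (which exposes actor_network_params, qvalue_network_params and log_alpha) and assuming a learning rate under cfg.optim.lr (not shown in this diff), such a helper could look like:

    import torch.optim as optim

    def make_sac_optimizer(cfg, loss_module):
        # Hypothetical sketch: split the SAC loss parameters into three groups so that
        # the actor, the critic and the entropy coefficient get independent Adam steps.
        actor_params = list(loss_module.actor_network_params.flatten_keys().values())
        critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
        optimizer_actor = optim.Adam(actor_params, lr=cfg.optim.lr)  # cfg.optim.lr is assumed
        optimizer_critic = optim.Adam(critic_params, lr=cfg.optim.lr)
        optimizer_alpha = optim.Adam([loss_module.log_alpha], lr=cfg.optim.lr)
        return optimizer_actor, optimizer_critic, optimizer_alpha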
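
Note on the logging change: the other new helper imported from utils.py is log_metrics, which replaces the inline per-key logger.log_scalar loop that the old code ran. Its actual implementation is not shown in this diff; a minimal sketch consistent with how the removed code called logger.log_scalar would be:

    def log_metrics(logger, metrics, step):
        # Forward each scalar in the metrics dict to the experiment logger,
        # tagging it with the number of frames collected so far.
        for metric_name, metric_value in metrics.items():
            logger.log_scalar(metric_name, metric_value, step=step)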