@@ -73,13 +73,6 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs):
         self.trainer = trainer
         self.training_loop.connect(trainer)

-    # TODO: is it used anywhere?
-    def should_accumulate(self):
-        return self.training_loop.batch_loop.should_accumulate()
-
-    def get_active_optimizers(self, batch_idx):
-        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)
-
     @property
     def done(self) -> bool:
         # TODO: Move track steps inside training loop and move part of these condition inside training loop
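The context above shows connect() storing the trainer and forwarding it into the nested training loop. A minimal, self-contained sketch of that chaining pattern follows; the class names are illustrative and not part of this diff.

class ChildLoopSketch:
    """Stands in for the nested training loop; name is illustrative."""

    def __init__(self):
        self.trainer = None

    def connect(self, trainer):
        self.trainer = trainer


class OuterLoopSketch:
    """Stands in for the fit-level loop shown in the diff."""

    def __init__(self, training_loop: ChildLoopSketch):
        self.trainer = None
        self.training_loop = training_loop

    def connect(self, trainer):
        # store the trainer locally, then forward it down the loop hierarchy
        self.trainer = trainer
        self.training_loop.connect(trainer)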
@@ -109,36 +102,6 @@ def on_run_start(self):
         # hook
         self.trainer.call_hook("on_train_start")

-    def on_run_end(self):
-        if self._teardown_already_run:
-            return
-        self._teardown_already_run = True
-
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.training_loop.global_step -= 1
-        # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.training_loop.global_step += 1
-
-        # hook
-        self.trainer.call_hook("on_train_end")
-
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
-        # give accelerators a chance to finish
-        self.trainer.accelerator.on_train_end()
-
-        # reset bookkeeping
-        self.trainer._running_stage = None
-
     def on_advance_start(self):  # equal to on train epoch start
         # implemented here since this code has to be run always no matter the actual epoch implementation
         epoch = self.iteration_count + 1
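Both on_run_start above and the epoch-start code in the next hunk route everything through trainer.call_hook. A rough sketch of what such a name-based dispatcher can look like; this is a generic illustration of the pattern, not Lightning's actual implementation.

def call_hook_sketch(trainer, hook_name: str, *args, **kwargs):
    """Generic name-based hook dispatch; an illustration only."""
    # let every registered callback react first, mirroring the
    # cb.on_validation_end(self.trainer, model) call style seen later in this diff
    for callback in getattr(trainer, "callbacks", []):
        callback_hook = getattr(callback, hook_name, None)
        if callable(callback_hook):
            callback_hook(trainer, trainer.lightning_module, *args, **kwargs)

    # then give the LightningModule itself a chance to handle the hook
    model_hook = getattr(trainer.lightning_module, hook_name, None)
    if callable(model_hook):
        return model_hook(*args, **kwargs)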
@@ -167,7 +130,14 @@ def on_advance_start(self): # equal to on train epoch start
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    # why is this not the same as the old on_train_epoch_end?
+    def advance(self):
+
+        with self.trainer.profiler.profile("run_training_epoch"):
+            # run train epoch
+            epoch_output = self.training_loop.run()
+            # log epoch metrics
+            self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
+
     def on_advance_end(self):
         # # handle epoch_output on epoch end
         # self.on_train_epoch_end(outputs)  # Handled in on_run_end of training_loop now
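advance() above now reduces to one profiled call into training_loop.run() plus epoch-level logging; the surrounding hooks only make sense relative to the driver that is assumed to call them. A minimal sketch of such a driver, using only the hook names that appear in this diff; every other name and the iteration bookkeeping are illustrative.

class LoopSketch:
    def __init__(self, max_iterations: int = 3):
        self.iteration_count = 0
        self.max_iterations = max_iterations

    @property
    def done(self) -> bool:
        return self.iteration_count >= self.max_iterations

    def on_run_start(self): ...
    def on_advance_start(self): ...   # e.g. the "on_train_epoch_start" hooks
    def advance(self): ...            # e.g. one profiled call into training_loop.run()
    def on_advance_end(self): ...     # e.g. checkpointing and bookkeeping
    def on_run_end(self): ...         # e.g. final checkpoint, logger/profiler teardown

    def run(self):
        # assumed call order: start once, then advance until done, then tear down
        self.on_run_start()
        while not self.done:
            self.on_advance_start()
            self.advance()
            self.on_advance_end()
            self.iteration_count += 1
        self.on_run_end()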
@@ -193,13 +163,42 @@ def on_advance_end(self):
         # TODO: move inside training_loop.on_run_end? equivalent? order?
         self.training_loop.increment_accumulated_grad_global_step()

-    def advance(self):
+    # why is this not the same as the old on_train_epoch_end?
+    def on_run_end(self):
+        if self._teardown_already_run:
+            return
+        self._teardown_already_run = True

-        with self.trainer.profiler.profile("run_training_epoch"):
-            # run train epoch
-            epoch_output = self.training_loop.run()
-            # log epoch metrics
-            self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.training_loop.global_step -= 1
+        # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.training_loop.global_step += 1
+
+        # hook
+        self.trainer.call_hook("on_train_end")
+
+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu
+        # kill loggers
+        if self.trainer.logger is not None:
+            self.trainer.logger.finalize("success")
+
+        # summarize profile results
+        self.trainer.profiler.describe()
+
+        # give accelerators a chance to finish
+        self.trainer.accelerator.on_train_end()
+
+        # reset bookkeeping
+        self.trainer._running_stage = None
+
+    def should_accumulate(self):
+        return self.training_loop.batch_loop.should_accumulate()
+
+    def get_active_optimizers(self, batch_idx):
+        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)

     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback
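The reordered on_run_end above temporarily decreases training_loop.global_step so the final checkpoint check behaves as if it ran at the last trained step, avoiding a duplicate save. The same pattern expressed as a reusable context manager; rolled_back_global_step and the names in the usage comment are placeholders, not Lightning APIs.

from contextlib import contextmanager


@contextmanager
def rolled_back_global_step(loop, steps: int = 1):
    """Run a block as if the step counter were still at the last trained step."""
    loop.global_step -= steps
    try:
        yield
    finally:
        loop.global_step += steps


# usage sketch (placeholder names, not Lightning APIs):
# with rolled_back_global_step(fit_loop.training_loop):
#     fit_loop.check_checkpoint_callback(should_update=True, is_last=True)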
@@ -213,3 +212,4 @@ def check_checkpoint_callback(self, should_update, is_last=False):

         for cb in callbacks:
             cb.on_validation_end(self.trainer, model)
+
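Only the tail of check_checkpoint_callback is visible in this hunk. How the `callbacks` list might be collected, stated as an assumption for illustration rather than as the elided code:

from pytorch_lightning.callbacks import ModelCheckpoint


def checkpoint_callbacks_sketch(trainer):
    """Hypothetical helper: keep only the ModelCheckpoint instances registered on the trainer."""
    return [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]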