
Commit b090e4f

standardize order of loop hooks
1 parent 57b4a32 commit b090e4f

File tree

pytorch_lightning/loops/base.py
pytorch_lightning/loops/batch_loop.py
pytorch_lightning/loops/epoch_loop.py
pytorch_lightning/loops/training_loop.py

4 files changed (+125 -125 lines)

pytorch_lightning/loops/base.py

Lines changed: 19 additions & 19 deletions
@@ -1,11 +1,11 @@
 from _weakref import proxy
-from abc import ABCMeta, abstractmethod
-from typing import Any, Counter, List, Optional
+from abc import abstractmethod, ABC
+from typing import Any, Optional

 import pytorch_lightning as pl


-class Loop(metaclass=ABCMeta):
+class Loop(ABC):

     def __init__(self):
         self.iteration_count: int = 0
@@ -21,22 +21,6 @@ def connect(self, trainer, *args, **kwargs):
     def done(self):
         """Property indicating when loop is finished"""

-    @abstractmethod
-    def advance(self, *args: Any, **kwargs: Any):
-        """What to do within a single step"""
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        pass
-
-    def on_run_end(self) -> Any:
-        pass
-
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        pass
-
-    def on_advance_end(self) -> None:
-        pass
-
     def run(self, *args: Any, **kwargs: Any):
         self.on_run_start(*args, **kwargs)

@@ -49,5 +33,21 @@ def run(self, *args: Any, **kwargs: Any):

         return self.on_run_end()

+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    @abstractmethod
+    def advance(self, *args: Any, **kwargs: Any):
+        """What to do within a single step"""
+
+    def on_advance_end(self) -> None:
+        pass
+
+    def on_run_end(self) -> Any:
+        pass
+
     def state_dict(self):
         return dict()
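Note: as a reading aid, here is a minimal, self-contained sketch of the hook order the reordered base class now reflects: on_run_start once, then on_advance_start / advance / on_advance_end per iteration, then on_run_end. The names SketchLoop and CountToThree are invented for illustration, and the while-not-done driver mirrors the run() override shown in training_loop.py below; this is not the actual pytorch_lightning.loops API.

from abc import ABC, abstractmethod
from typing import Any


class SketchLoop(ABC):
    def __init__(self) -> None:
        self.iteration_count: int = 0

    @property
    @abstractmethod
    def done(self) -> bool:
        """Property indicating when loop is finished"""

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.on_run_start(*args, **kwargs)
        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
            except StopIteration:
                break
            self.iteration_count += 1
        return self.on_run_end()

    # hooks below are declared in the order run() calls them
    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None:
        """What to do within a single step"""

    def on_advance_end(self) -> None:
        pass

    def on_run_end(self) -> Any:
        pass


class CountToThree(SketchLoop):
    @property
    def done(self) -> bool:
        return self.iteration_count >= 3

    def advance(self) -> None:
        print("step", self.iteration_count)


CountToThree().run()  # prints: step 0 / step 1 / step 2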

pytorch_lightning/loops/batch_loop.py

Lines changed: 25 additions & 25 deletions
@@ -44,6 +44,31 @@ def connect(self, trainer, *args, **kwargs):
     def done(self):
         return len(self._remaining_splits) == 0

+    def run(self, batch, batch_idx, dataloader_idx):
+        if batch is None:
+            return AttributeDict(signal=0, grad_norm_dic={})
+
+        # hook
+        response = self.trainer.call_hook("on_batch_start")
+        if response == -1:
+            return AttributeDict(signal=-1, grad_norm_dic={})
+
+        # hook
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
+        if response == -1:
+            return AttributeDict(signal=-1, grad_norm_dic={})
+
+        super().run(batch, batch_idx, dataloader_idx)
+
+        output = AttributeDict(
+            signal=0,
+            # todo: Properly aggregate grad_norm accros opt_idx and split_idx
+            # grad_norm_dict=grad_norm_dict,
+            grad_norm_dict={},
+            training_step_output_for_epoch_end=self.batch_outputs,
+        )
+        return output
+
     def on_run_start(self, batch, batch_idx, dataloader_idx):
         self._hiddens = None
         self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch)))
@@ -70,31 +95,6 @@ def advance(self, batch, batch_idx, dataloader_idx):
            if result:
                self.batch_outputs[0].append(result.training_step_output_for_epoch_end)

-    def run(self, batch, batch_idx, dataloader_idx):
-        if batch is None:
-            return AttributeDict(signal=0, grad_norm_dic={})
-
-        # hook
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1, grad_norm_dic={})
-
-        # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
-        if response == -1:
-            return AttributeDict(signal=-1, grad_norm_dic={})
-
-        super().run(batch, batch_idx, dataloader_idx)
-
-        output = AttributeDict(
-            signal=0,
-            # todo: Properly aggregate grad_norm accros opt_idx and split_idx
-            # grad_norm_dict=grad_norm_dict,
-            grad_norm_dict={},
-            training_step_output_for_epoch_end=self.batch_outputs,
-        )
-        return output
-
     # ------------------------------------------------------------------------------------------------------------
     # HELPER --- TO BE CLEANED UP
     # ------------------------------------------------------------------------------------------------------------
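Note: the relocated run() above short-circuits a batch when either start hook answers -1. Below is a reduced sketch of that control flow using stub stand-ins (StubTrainer, run_batch are invented names; they are not the real Trainer, call_hook, or AttributeDict).

from typing import Optional


class StubTrainer:
    # Stand-in for the real Trainer: call_hook() returns -1 only for the
    # configured hook name, mimicking a hook that requests an early exit.
    def __init__(self, abort_hook: str = "") -> None:
        self.abort_hook = abort_hook

    def call_hook(self, name: str, *args) -> Optional[int]:
        return -1 if name == self.abort_hook else None


def run_batch(trainer: StubTrainer, batch) -> dict:
    # Mirrors the early-exit shape of the relocated run(): a None batch or a
    # -1 response from either start hook skips the splits entirely.
    if batch is None:
        return {"signal": 0, "grad_norm_dic": {}}
    for hook in ("on_batch_start", "on_train_batch_start"):
        if trainer.call_hook(hook, batch) == -1:
            return {"signal": -1, "grad_norm_dic": {}}
    # ...the real loop would now run the tbptt splits via super().run()...
    return {"signal": 0, "grad_norm_dic": {}}


print(run_batch(StubTrainer("on_train_batch_start"), batch=[1, 2, 3]))
# -> {'signal': -1, 'grad_norm_dic': {}}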

pytorch_lightning/loops/epoch_loop.py

Lines changed: 44 additions & 44 deletions
@@ -73,13 +73,6 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs):
         self.trainer = trainer
         self.training_loop.connect(trainer)

-    # TODO: is it used anywhere?
-    def should_accumulate(self):
-        return self.training_loop.batch_loop.should_accumulate()
-
-    def get_active_optimizers(self, batch_idx):
-        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)
-
     @property
     def done(self) -> bool:
         # TODO: Move track steps inside training loop and move part of these condition inside training loop
@@ -109,36 +102,6 @@ def on_run_start(self):
         # hook
         self.trainer.call_hook("on_train_start")

-    def on_run_end(self):
-        if self._teardown_already_run:
-            return
-        self._teardown_already_run = True
-
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.training_loop.global_step -= 1
-        # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.training_loop.global_step += 1
-
-        # hook
-        self.trainer.call_hook("on_train_end")
-
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
-        # give accelerators a chance to finish
-        self.trainer.accelerator.on_train_end()
-
-        # reset bookkeeping
-        self.trainer._running_stage = None
-
     def on_advance_start(self):  # equal to on train epoch start
         # implemented here since this code has to be run always no matter the actual epoch implementation
         epoch = self.iteration_count + 1
@@ -167,7 +130,14 @@ def on_advance_start(self):  # equal to on train epoch start
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    # why is this not the same as the old on_train_epoch_end?
+    def advance(self):
+
+        with self.trainer.profiler.profile("run_training_epoch"):
+            # run train epoch
+            epoch_output = self.training_loop.run()
+            # log epoch metrics
+            self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
+
     def on_advance_end(self):
         # # handle epoch_output on epoch end
         # self.on_train_epoch_end(outputs)  # Handled in on_run_end of training_loop now
@@ -193,13 +163,42 @@ def on_advance_end(self):
         # TODO: move inside training_loop.on_run_end? equivalent? order?
         self.training_loop.increment_accumulated_grad_global_step()

-    def advance(self):
+    # why is this not the same as the old on_train_epoch_end?
+    def on_run_end(self):
+        if self._teardown_already_run:
+            return
+        self._teardown_already_run = True

-        with self.trainer.profiler.profile("run_training_epoch"):
-            # run train epoch
-            epoch_output = self.training_loop.run()
-            # log epoch metrics
-            self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.training_loop.global_step -= 1
+        # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.training_loop.global_step += 1
+
+        # hook
+        self.trainer.call_hook("on_train_end")
+
+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu
+        # kill loggers
+        if self.trainer.logger is not None:
+            self.trainer.logger.finalize("success")
+
+        # summarize profile results
+        self.trainer.profiler.describe()
+
+        # give accelerators a chance to finish
+        self.trainer.accelerator.on_train_end()
+
+        # reset bookkeeping
+        self.trainer._running_stage = None
+
+    def should_accumulate(self):
+        return self.training_loop.batch_loop.should_accumulate()
+
+    def get_active_optimizers(self, batch_idx):
+        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)

     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback
@@ -213,3 +212,4 @@ def check_checkpoint_callback(self, should_update, is_last=False):

         for cb in callbacks:
             cb.on_validation_end(self.trainer, model)
+
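Note: read together with training_loop.py below, the trainer-level hooks visible at the call_hook(...) sites in this commit fire in the following order for a single training epoch. The recorder here is a stub, not the real Trainer.call_hook, and the per-batch hooks are elided.

fired = []


def call_hook(name: str, *args) -> None:
    fired.append(name)


call_hook("on_train_start")        # EpochLoop.on_run_start
call_hook("on_epoch_start")        # EpochLoop.on_advance_start
call_hook("on_train_epoch_start")  # EpochLoop.on_advance_start
# ...per-batch hooks such as on_batch_start / on_train_batch_start run here...
call_hook("on_epoch_end")          # TrainingLoop.on_run_end
call_hook("on_train_end")          # EpochLoop.on_run_end

print(fired)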

pytorch_lightning/loops/training_loop.py

Lines changed: 37 additions & 37 deletions
@@ -45,6 +45,43 @@ def connect(self, trainer: 'pl.Trainer', *args, **kwargs):
         self.batch_loop = BatchLoop()
         self.batch_loop.connect(trainer)

+    @property
+    def done(self):
+        # max steps reached, end training
+        if (
+            self.max_steps is not None and self.max_steps <= self.global_step + 1
+            and self.batch_loop._accumulated_batches_reached()
+        ):
+            return True
+
+        # end epoch early
+        # stop when the flag is changed or we've gone past the amount
+        # requested in the batches
+        if self.trainer.should_stop:
+            return True
+
+        # TODO: moved to on_advance_end, check if correct?
+        # self.total_batch_idx += 1
+
+        # stop epoch if we limited the number of training batches
+        if self._num_training_batches_reached(self.is_last_batch):
+            return True
+
+    def run(self, *args, **kwargs):
+        self.on_run_start()
+
+        while True:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+            except StopIteration:
+                break
+
+            self.iteration_count += 1
+
+        return self.on_run_end()
+
     def on_run_start(self):
         # modify dataloader if needed (ddp, etc...)
         train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
@@ -121,28 +158,6 @@ def on_advance_end(self):
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

-    @property
-    def done(self):
-        # max steps reached, end training
-        if (
-            self.max_steps is not None and self.max_steps <= self.global_step + 1
-            and self.batch_loop._accumulated_batches_reached()
-        ):
-            return True
-
-        # end epoch early
-        # stop when the flag is changed or we've gone past the amount
-        # requested in the batches
-        if self.trainer.should_stop:
-            return True
-
-        # TODO: moved to on_advance_end, check if correct?
-        # self.total_batch_idx += 1
-
-        # stop epoch if we limited the number of training batches
-        if self._num_training_batches_reached(self.is_last_batch):
-            return True
-
     # this is the old on train_epoch_end?
     def on_run_end(self):
         # inform logger the batch loop has finished
@@ -176,21 +191,6 @@ def on_run_end(self):
         self.trainer.call_hook('on_epoch_end')
         return self.epoch_output

-    def run(self, *args, **kwargs):
-        self.on_run_start()
-
-        while True:
-            try:
-                self.on_advance_start()
-                self.advance()
-                self.on_advance_end()
-            except StopIteration:
-                break
-
-            self.iteration_count += 1
-
-        return self.on_run_end()
-
     # ------------------------------------------------------------------------------------------------------------
     # HELPER --- TO BE CLEANED UP
     # ------------------------------------------------------------------------------------------------------------
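Note: the run() override moved above relies on StopIteration as its exit signal; whichever of on_advance_start, advance, or on_advance_end raises it ends the epoch cleanly before on_run_end runs. Below is a reduced, hypothetical sketch of that control flow (run_epoch is an invented name, not part of the codebase).

def run_epoch(batches) -> list:
    # Reduced sketch of the relocated run(): loop until StopIteration, then
    # break out and hand results back (the on_run_end step in the real loop).
    outputs = []
    it = iter(batches)
    iteration_count = 0
    while True:
        try:
            batch = next(it)           # on_advance_start: fetch the next batch
            outputs.append(batch * 2)  # advance: process the batch
        except StopIteration:          # dataloader exhausted -> leave the loop
            break
        iteration_count += 1
    return outputs


print(run_epoch([1, 2, 3]))  # -> [2, 4, 6]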
