
Commit 09167ef

Borda, Adrian Wälchli, and williamFalcon authored
Checkpointing interval (#1272)
* formatting
* formatting
* fix interval
* fix train loop
* fix test
* parametrize test
* Apply suggestions from code review
  Co-Authored-By: Adrian Wälchli <[email protected]>
* fix calling
* flake8
* add types

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 3476d2f commit 09167ef

File tree

15 files changed (+166, -298 lines)


pl_examples/multi_node_examples/multi_node_ddp_demo.py

Lines changed: 1 addition & 5 deletions
@@ -16,11 +16,7 @@
 
 
 def main(hparams):
-    """
-    Main training routine specific for this project
-    :param hparams:
-    :return:
-    """
+    """Main training routine specific for this project."""
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------

pytorch_lightning/callbacks/base.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 r"""
 Callback Base
-==============
+=============
 Abstract base class used to build new callbacks.
 """

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 42 additions & 48 deletions
@@ -94,7 +94,7 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
-        self.epochs_since_last_check = 0
+        self.epoch_last_check = None
         self.prefix = prefix
         self.best_k_models = {}
         # {filename: monitor}
@@ -139,21 +139,20 @@ def check_monitor_top_k(self, current):
     def format_checkpoint_name(self, epoch, metrics, ver=None):
         """Generate a filename according define template.

-        Examples
-        --------
-        >>> tmpdir = os.path.dirname(__file__)
-        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
-        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-        'epoch=0.ckpt'
-        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
-        >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
-        'epoch=005.ckpt'
-        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
-        >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
-        'epoch=2-val_loss=0.12.ckpt'
-        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
-        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-        'missing=0.ckpt'
+        Examples:
+            >>> tmpdir = os.path.dirname(__file__)
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            'epoch=0.ckpt'
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            'epoch=005.ckpt'
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            'epoch=2-val_loss=0.12.ckpt'
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            'missing=0.ckpt'
         """
         # check if user passed in keys to the string
         groups = re.findall(r'(\{.*?)[:\}]', self.filename)
@@ -181,41 +180,36 @@ def on_validation_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
-        self.epochs_since_last_check += 1
-
         if self.save_top_k == 0:
             # no models are saved
             return
-        if self.epochs_since_last_check >= self.period:
-            self.epochs_since_last_check = 0
-
-            filepath = self.format_checkpoint_name(epoch, metrics)
-            version_cnt = 0
-            while os.path.isfile(filepath):
-                filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
-                # this epoch called before
-                version_cnt += 1
-
-            if self.save_top_k != -1:
-                current = metrics.get(self.monitor)
-
-                if current is None:
-                    warnings.warn(
-                        f'Can save best model only with {self.monitor} available,'
-                        ' skipping.', RuntimeWarning)
-                else:
-                    if self.check_monitor_top_k(current):
-                        self._do_check_save(filepath, current, epoch)
-                    else:
-                        if self.verbose > 0:
-                            log.info(
-                                f'\nEpoch {epoch:05d}: {self.monitor}'
-                                f' was not in top {self.save_top_k}')
-
-            else:
-                if self.verbose > 0:
-                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
-                self._save_model(filepath)
+        if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
+            # skipping in this term
+            return
+
+        self.epoch_last_check = epoch
+
+        filepath = self.format_checkpoint_name(epoch, metrics)
+        version_cnt = 0
+        while os.path.isfile(filepath):
+            filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+            # this epoch called before
+            version_cnt += 1
+
+        if self.save_top_k != -1:
+            current = metrics.get(self.monitor)
+
+            if current is None:
+                warnings.warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
+            elif self.check_monitor_top_k(current):
+                self._do_check_save(filepath, current, epoch)
+            elif self.verbose > 0:
+                log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
+
+        else:
+            if self.verbose > 0:
+                log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
+            self._save_model(filepath)
 
     def _do_check_save(self, filepath, current, epoch):
         # remove kth
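For context (not part of the diff): with `epoch_last_check`, the callback compares the current epoch against the epoch of the last save instead of counting validation calls, so validating more than once per epoch no longer burns through the `period` budget. A minimal sketch of the intended behaviour, assuming only the constructor arguments visible in this diff (`filepath`, `period`); the standalone `should_skip` helper below is hypothetical and just restates the new condition:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # save at most once every 2 epochs
    ckpt = ModelCheckpoint(filepath='checkpoints/{epoch}', period=2)

    def should_skip(epoch, epoch_last_check, period):
        # skip while fewer than `period` epochs have passed since the last save
        return epoch_last_check is not None and (epoch - epoch_last_check) < period

    assert should_skip(epoch=1, epoch_last_check=0, period=2) is True
    assert should_skip(epoch=2, epoch_last_check=0, period=2) is False
    assert should_skip(epoch=5, epoch_last_check=None, period=2) is False  # first check always saves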

pytorch_lightning/profiler/profiler.py

Lines changed: 20 additions & 19 deletions
@@ -17,15 +17,15 @@ class BaseProfiler(ABC):
     """

     @abstractmethod
-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""

     @abstractmethod
-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""

     @contextmanager
-    def profile(self, action_name):
+    def profile(self, action_name: str) -> None:
         """
         Yields a context manager to encapsulate the scope of a profiled action.
@@ -43,7 +43,7 @@ def profile(self, action_name):
         finally:
             self.stop(action_name)

-    def profile_iterable(self, iterable, action_name):
+    def profile_iterable(self, iterable, action_name: str) -> None:
         iterator = iter(iterable)
         while True:
             try:
@@ -55,7 +55,7 @@ def profile_iterable(self, iterable, action_name):
                 self.stop(action_name)
                 break

-    def describe(self):
+    def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
         pass

@@ -69,10 +69,10 @@ class PassThroughProfiler(BaseProfiler):
     def __init__(self):
         pass

-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         pass

-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         pass

@@ -86,14 +86,14 @@ def __init__(self):
         self.current_actions = {}
         self.recorded_durations = defaultdict(list)

-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         if action_name in self.current_actions:
             raise ValueError(
                 f"Attempted to start {action_name} which has already started."
             )
         self.current_actions[action_name] = time.monotonic()

-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         end_time = time.monotonic()
         if action_name not in self.current_actions:
             raise ValueError(
@@ -103,7 +103,7 @@ def stop(self, action_name):
         duration = end_time - start_time
         self.recorded_durations[action_name].append(duration)

-    def describe(self):
+    def describe(self) -> None:
         output_string = "\n\nProfiler Report\n"

         def log_row(action, mean, total):
@@ -126,32 +126,33 @@ class AdvancedProfiler(BaseProfiler):
     verbose and you should only use this if you want very detailed reports.
     """

-    def __init__(self, output_filename=None, line_count_restriction=1.0):
+    def __init__(self, output_filename: str = None, line_count_restriction: float = 1.0):
         """
-        :param output_filename (str): optionally save profile results to file instead of printing
-            to std out when training is finished.
-        :param line_count_restriction (int|float): this can be used to limit the number of functions
-            reported for each action. either an integer (to select a count of lines),
-            or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
+        Args:
+            output_filename: optionally save profile results to file instead of printing
+                to std out when training is finished.
+            line_count_restriction: this can be used to limit the number of functions
+                reported for each action. either an integer (to select a count of lines),
+                or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
         """
         self.profiled_actions = {}
         self.output_filename = output_filename
         self.line_count_restriction = line_count_restriction

-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         if action_name not in self.profiled_actions:
             self.profiled_actions[action_name] = cProfile.Profile()
         self.profiled_actions[action_name].enable()

-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         pr = self.profiled_actions.get(action_name)
         if pr is None:
             raise ValueError(  # pragma: no-cover
                 f"Attempting to stop recording an action ({action_name}) which was never started."
             )
         pr.disable()

-    def describe(self):
+    def describe(self) -> None:
         self.recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
             s = io.StringIO()
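For context (not part of the diff): the newly annotated `start`/`stop`/`describe` methods, plus `BaseProfiler.profile()` shown above, can be exercised on their own. A minimal sketch using `AdvancedProfiler`, whose class and constructor arguments appear in this hunk; the import path is an assumption:

    import time
    from pytorch_lightning.profiler import AdvancedProfiler  # import path assumed

    profiler = AdvancedProfiler(output_filename=None, line_count_restriction=0.5)

    # explicit start/stop around a named action
    profiler.start('data_loading')
    time.sleep(0.01)
    profiler.stop('data_loading')

    # or wrap a block with the context manager defined on BaseProfiler.profile()
    with profiler.profile('forward'):
        time.sleep(0.01)

    profiler.describe()  # summarizes the recorded cProfile stats per action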

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 3 additions & 4 deletions
@@ -199,10 +199,9 @@ def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
                 self.use_ddp2 = distributed_backend == 'ddp2'

             elif distributed_backend is None:
-                m = 'You requested multiple GPUs but did not specify a backend' \
-                    'Trainer(distributed_backend=dp) (or ddp, ddp2)' \
-                    'Setting distributed_backend=dp for you'
-                warnings.warn(m)
+                warnings.warn('You requested multiple GPUs but did not specify a backend, e.g.'
+                              ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
+                              ' Setting distributed_backend=dp for you.')
                 self.use_dp = True
                 self.use_ddp = False
                 self.use_ddp2 = False
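For context (not part of the diff): the reworded warning fires when multiple GPUs are requested without choosing a backend, in which case the trainer falls back to DataParallel. A minimal sketch of the two configurations, assuming a standard Trainer setup:

    from pytorch_lightning import Trainer

    # explicit backend: no warning
    trainer = Trainer(gpus=2, distributed_backend='ddp')

    # backend omitted: emits the warning above and defaults to distributed_backend=dp
    trainer = Trainer(gpus=2)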

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 10 additions & 14 deletions
@@ -491,9 +491,8 @@ def tpu_train(self, tpu_core_idx, model):
         if self.precision == 16:
             os.environ['XLA_USE_BF16'] = str(1)

-        m = f'INIT TPU local core: {self.tpu_local_core_rank}, ' \
-            f'global rank: {self.tpu_global_core_rank}'
-        log.info(m)
+        log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
+                 f' global rank: {self.tpu_global_core_rank}')

         # continue training routine
         self.run_pretrain_routine(model)
@@ -512,12 +511,10 @@ def dp_train(self, model):
         # https://github.com/NVIDIA/apex/issues/227
         if self.use_dp and self.use_amp:
             if self.amp_level == 'O2':
-                m = f"""
-                Amp level {self.amp_level} with DataParallel is not supported.
-                See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
-                We recommend you switch to ddp if you want to use amp
-                """
-                raise MisconfigurationException(m)
+                raise MisconfigurationException(
+                    f'Amp level {self.amp_level} with DataParallel is not supported.'
+                    f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
+                    f' We recommend you switch to ddp if you want to use amp')
             else:
                 model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)

@@ -584,11 +581,10 @@ def sanitize_gpu_ids(gpus):
     all_available_gpus = get_all_available_gpus()
     for gpu in gpus:
         if gpu not in all_available_gpus:
-            message = f"""
-            You requested GPUs: {gpus}
-            But your machine only has: {all_available_gpus}
-            """
-            raise MisconfigurationException(message)
+            raise MisconfigurationException(f"""
+            You requested GPUs: {gpus}
+            But your machine only has: {all_available_gpus}
+            """)
     return gpus

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 3 additions & 3 deletions
@@ -322,9 +322,9 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
     def run_evaluation(self, test_mode: bool = False):
         # when testing make sure user defined a test step
         if test_mode and not self.is_overriden('test_step'):
-            m = "You called `.test()` without defining model's `.test_step()`." \
-                " Please define and try again"
-            raise MisconfigurationException(m)
+            raise MisconfigurationException(
+                "You called `.test()` without defining model's `.test_step()`."
+                " Please define and try again")

         # Validation/Test begin callbacks
         if test_mode:

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 13 deletions
@@ -328,11 +328,8 @@ def __init__(
         if self.fast_dev_run:
             self.num_sanity_val_steps = 1
             self.max_epochs = 1
-            m = '''
-            Running in fast_dev_run mode: will run a full train,
-            val loop using a single batch
-            '''
-            log.info(m)
+            log.info('Running in fast_dev_run mode: will run a full train,'
+                     ' val loop using a single batch')

         # set default save path if user didn't provide one
         self.default_save_path = default_save_path
@@ -739,22 +736,22 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da
         # functions to overwrite with these implementations
         if train_dataloader is not None:
             if not self.is_overriden('training_step', model):
-                m = 'You called .fit() with a train_dataloader but did not define training_step()'
-                raise MisconfigurationException(m)
+                raise MisconfigurationException(
+                    'You called `.fit()` with a `train_dataloader` but did not define `training_step()`')

             model.train_dataloader = _PatchDataLoader(train_dataloader)

         if val_dataloaders is not None:
             if not self.is_overriden('validation_step', model):
-                m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
-                raise MisconfigurationException(m)
+                raise MisconfigurationException(
+                    'You called `.fit()` with a `val_dataloaders` but did not define `validation_step()`')

             model.val_dataloader = _PatchDataLoader(val_dataloaders)

         if test_dataloaders is not None:
             if not self.is_overriden('test_step', model):
-                m = 'You called .fit() with a test_dataloaders but did not define test_step()'
-                raise MisconfigurationException(m)
+                raise MisconfigurationException(
+                    'You called `.fit()` with a `test_dataloaders` but did not define `test_step()`')

             model.test_dataloader = _PatchDataLoader(test_dataloaders)

@@ -855,8 +852,7 @@ def run_pretrain_routine(self, model: LightningModule):
         if self.weights_summary in ['full', 'top']:
             ref_model.summarize(mode=self.weights_summary)
         else:
-            m = "weights_summary can be None, 'full' or 'top'"
-            raise MisconfigurationException(m)
+            raise MisconfigurationException("weights_summary can be None, 'full' or 'top'")

         # track model now.
         # if cluster resets state, the model will update with the saved weights
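For context (not part of the diff): the reworded exceptions in `__attach_dataloaders` are raised from `.fit()` when a dataloader is passed for a phase whose step method is missing. A minimal sketch of the pairing these checks enforce, with a hypothetical `LitModel` standing in for a user LightningModule:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # hypothetical minimal model used only to illustrate the checks above
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self(x), y)
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)

    trainer = pl.Trainer(fast_dev_run=True)  # runs a single batch, as logged above
    # passing train_dataloader is fine because training_step() is defined;
    # passing val_dataloaders here would raise MisconfigurationException,
    # since validation_step() is not implemented
    trainer.fit(LitModel(), train_dataloader=loader)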
