
Commit 9cfbf8d

rohitgr7 authored and Borda committed
Disable checkpointing, earlystopping and logging with fast_dev_run (#5277)
* Disable checkpointing, earlystopping and logger with fast_dev_run
* docs
* chlog
* disable callbacks and enable DummyLogger
* add log
* use dummy logger method
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit f740245)
1 parent bb36623 commit 9cfbf8d
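At a glance, the behaviour this commit ships — a minimal sketch, assuming a pytorch-lightning build that includes the change (e.g. 1.1.3 or later):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers.base import DummyLogger

    # fast_dev_run now replaces whatever logger was configured with a no-op
    # DummyLogger, so no experiment files or metrics are written during the
    # debug run; checkpointing and early stopping are likewise skipped.
    trainer = Trainer(fast_dev_run=True)
    assert isinstance(trainer.logger, DummyLogger)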

File tree

16 files changed: +165 -100 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
+
 
 
 ## [1.1.2] - 2020-12-23

docs/source/debugging.rst

Lines changed: 6 additions & 1 deletion
@@ -28,13 +28,18 @@ The point is to detect any bugs in the training/validation loop without having t
 argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. testcode::
-
+
     # runs 1 train, val, test batch and program ends
     trainer = Trainer(fast_dev_run=True)
 
     # runs 7 train, val, test batches and program ends
    trainer = Trainer(fast_dev_run=7)
 
+.. note::
+
+    This argument will disable tuner, checkpoint callbacks, early stopping callbacks,
+    loggers and logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch.
+
 ----------------
 
 Inspect gradient norms
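A hedged illustration of the note above (callback classes and Trainer arguments as in pytorch-lightning ~1.1): the callbacks stay attached to the Trainer, they are simply skipped while fast_dev_run is active.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    trainer = Trainer(
        fast_dev_run=True,
        callbacks=[EarlyStopping(monitor='val_loss'), ModelCheckpoint(monitor='val_loss')],
    )
    # Both callbacks are still listed, but with fast_dev_run they will neither
    # write checkpoints nor request early stopping during the run.
    assert trainer.checkpoint_callback is not None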

docs/source/trainer.rst

Lines changed: 3 additions & 3 deletions
@@ -666,9 +666,9 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a
 .. note::
 
     This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
-    disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be
-    used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't
-    disable anything.
+    disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like
+    ``LearningRateLogger`` and runs for only 1 epoch. This must be used only for debugging purposes.
+    ``limit_train/val/test_batches`` only limits the number of batches and won't disable anything.
 
 flush_logs_every_n_steps
 ^^^^^^^^^^^^^^^^^^^^^^^^
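For contrast, a quick illustrative snippet (not taken from the docs) of the two debugging styles the note distinguishes:

    from pytorch_lightning import Trainer

    # one batch per loop, 1 epoch, tuner/checkpointing/early stopping/loggers disabled
    debug_trainer = Trainer(fast_dev_run=True)

    # only caps the number of training batches; callbacks and loggers run as usual
    subset_trainer = Trainer(limit_train_batches=10)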

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 8 additions & 13 deletions
@@ -164,10 +164,10 @@ def on_validation_end(self, trainer, pl_module):
         self._run_early_stopping_check(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
-        if trainer.running_sanity_check:
+        if trainer.fast_dev_run or trainer.running_sanity_check:
             return
 
-        if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
+        if self._validate_condition_metric(trainer.callback_metrics):
             # turn off early stopping in on_train_epoch_end
             self.based_on_eval_results = True
 
@@ -176,24 +176,19 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         if self.based_on_eval_results:
             return
 
-        # early stopping can also work in the train loop when there is no val loop
-        should_check_early_stop = False
-
-        # fallback to monitor key in result dict
-        if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
-            should_check_early_stop = True
-
-        if should_check_early_stop:
-            self._run_early_stopping_check(trainer, pl_module)
+        self._run_early_stopping_check(trainer, pl_module)
 
     def _run_early_stopping_check(self, trainer, pl_module):
         """
         Checks whether the early stopping condition is met
         and if so tells the trainer to stop the training.
         """
-        logs = trainer.logger_connector.callback_metrics
+        logs = trainer.callback_metrics
 
-        if not self._validate_condition_metric(logs):
+        if (
+            trainer.fast_dev_run  # disable early_stopping with fast_dev_run
+            or not self._validate_condition_metric(logs)  # short circuit if metric not present
+        ):
             return  # short circuit if metric not present
 
         current = logs.get(self.monitor)
pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,7 @@
 import shutil
 import subprocess
 import time
-from typing import List, Tuple, Dict
+from typing import Dict, List, Tuple
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_only
@@ -213,5 +213,4 @@ def _should_log(trainer) -> bool:
             or trainer.should_stop
         )
 
-        should_log = should_log and not trainer.fast_dev_run
         return should_log

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 2 additions & 3 deletions
@@ -105,15 +105,15 @@ def on_train_batch_start(self, trainer, *args, **kwargs):
             interval = 'step' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)
 
-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_train_epoch_start(self, trainer, *args, **kwargs):
         if self.logging_interval != 'step':
             interval = 'epoch' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)
 
-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
@@ -190,5 +190,4 @@ def _should_log(trainer) -> bool:
             or trainer.should_stop
         )
 
-        should_log = should_log and not trainer.fast_dev_run
         return should_log
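The `trainer.logger is not None` guard can be dropped here in part because a fast_dev_run Trainer now always carries a logger, just a no-op one. A small hedged sketch of that assumption:

    from pytorch_lightning.loggers.base import DummyLogger

    logger = DummyLogger()
    # log_metrics on the dummy logger is a no-op, so the monitor can call
    # trainer.logger.log_metrics unconditionally without writing anything.
    logger.log_metrics({'lr-Adam': 0.001}, step=0)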

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 5 deletions
@@ -20,11 +20,11 @@
 
 """
 
-from copy import deepcopy
 import numbers
 import os
-from pathlib import Path
 import re
+from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
 import numpy as np
@@ -215,7 +215,8 @@ def save_checkpoint(self, trainer, pl_module):
         global_step = trainer.global_step
 
         if (
-            self.save_top_k == 0  # no models are saved
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
@@ -450,14 +451,14 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
             version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))
 
             ckpt_path = os.path.join(
-                save_dir, name, version, "checkpoints"
+                save_dir, str(name), version, "checkpoints"
             )
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
 
         self.dirpath = ckpt_path
 
-        if trainer.is_global_zero:
+        if not trainer.fast_dev_run and trainer.is_global_zero:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer):
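A minimal sketch of the new early return in `save_checkpoint`, based only on the hunk above (the namespace object is a stand-in for a real Trainer, not a supported call pattern): with fast_dev_run set, the first condition of the `or` chain is truthy and the method returns without touching the filesystem.

    from types import SimpleNamespace
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_cb = ModelCheckpoint()
    fake_trainer = SimpleNamespace(current_epoch=0, global_step=0, fast_dev_run=True)

    # No checkpoint directory is resolved and no file is written; the guard bails out first.
    checkpoint_cb.save_checkpoint(fake_trainer, pl_module=None)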

pytorch_lightning/callbacks/progress.py

Lines changed: 1 addition & 2 deletions
@@ -22,7 +22,6 @@
 import importlib
 import sys
 
-
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
 if importlib.util.find_spec('ipywidgets') is not None:
@@ -323,7 +322,7 @@ def on_epoch_start(self, trainer, pl_module):
         super().on_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
-        if total_train_batches != float('inf') and not trainer.fast_dev_run:
+        if total_train_batches != float('inf'):
             # val can be checked multiple times per epoch
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
             total_val_batches = total_val_batches * val_checks_per_epoch

pytorch_lightning/trainer/connectors/debugging_connector.py

Lines changed: 9 additions & 2 deletions
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from typing import Union
+
+from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class DebuggingConnector:
@@ -54,11 +56,16 @@ def on_init_start(
             limit_train_batches = fast_dev_run
             limit_val_batches = fast_dev_run
             limit_test_batches = fast_dev_run
+            self.trainer.max_steps = fast_dev_run
             self.trainer.num_sanity_val_steps = 0
             self.trainer.max_epochs = 1
+            self.trainer.val_check_interval = 1.0
+            self.trainer.check_val_every_n_epoch = 1
+            self.trainer.logger = DummyLogger()
+
             rank_zero_info(
                 'Running in fast_dev_run mode: will run a full train,'
-                f' val and test loop using {fast_dev_run} batch(es)'
+                f' val and test loop using {fast_dev_run} batch(es).'
             )
 
         self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
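A quick check of the connector's effect, using only attributes named in the hunk above (assumes a build containing this commit):

    from pytorch_lightning import Trainer

    trainer = Trainer(fast_dev_run=7)

    assert trainer.max_steps == 7 and trainer.max_epochs == 1
    assert trainer.limit_train_batches == 7   # val/test limits are set the same way
    assert trainer.num_sanity_val_steps == 0  # sanity check is skipped entirely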

pytorch_lightning/trainer/properties.py

Lines changed: 26 additions & 6 deletions
@@ -15,10 +15,10 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
-from typing import List, Optional, Type, TypeVar, Union, cast
+from typing import cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -199,7 +199,7 @@ def enable_validation(self) -> bool:
         """ Check if we should run validation during training. """
         model_ref = self.model_connector.get_model()
         val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
-        return val_loop_enabled or self.fast_dev_run
+        return val_loop_enabled
 
     @property
     def default_root_dir(self) -> str:
@@ -221,18 +221,38 @@ def weights_save_path(self) -> str:
             return os.path.normpath(self._weights_save_path)
         return self._weights_save_path
 
+    @property
+    def early_stopping_callback(self) -> Optional[EarlyStopping]:
+        """
+        The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
+        """
+        callbacks = self.early_stopping_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def early_stopping_callbacks(self) -> List[EarlyStopping]:
+        """
+        A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+        found in the Trainer.callbacks list.
+        """
+        return [c for c in self.callbacks if isinstance(c, EarlyStopping)]
+
     @property
     def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
         """
-        The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
-        no checkpoint callbacks exist.
+        The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
         """
         callbacks = self.checkpoint_callbacks
         return callbacks[0] if len(callbacks) > 0 else None
 
     @property
     def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
-        """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
+        """
+        A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+        found in the Trainer.callbacks list.
+        """
         return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
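The new properties mirror the existing checkpoint_callback/checkpoint_callbacks pair; a small usage sketch (Trainer arguments as in pytorch-lightning ~1.1):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    trainer = Trainer(callbacks=[EarlyStopping(monitor='val_loss')])

    # First EarlyStopping instance found in trainer.callbacks, or None if absent.
    assert isinstance(trainer.early_stopping_callback, EarlyStopping)
    assert len(trainer.early_stopping_callbacks) == 1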
