
Commit f740245

rohitgr7, awaelchli, tchaton, and Borda authored
Disable checkpointing, earlystopping and logging with fast_dev_run (#5277)
* Disable checkpointing, earlystopping and logger with fast_dev_run
* docs
* chlog
* disable callbacks and enable DummyLogger
* add log
* use dummy logger method
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent b0051e8 commit f740245

File tree

16 files changed: +168 / -103 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
+
 
 
 ## [1.1.2] - 2020-12-23

docs/source/debugging.rst

Lines changed: 6 additions & 1 deletion

@@ -28,13 +28,18 @@ The point is to detect any bugs in the training/validation loop without having t
 argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. testcode::
-
+
     # runs 1 train, val, test batch and program ends
     trainer = Trainer(fast_dev_run=True)
 
     # runs 7 train, val, test batches and program ends
     trainer = Trainer(fast_dev_run=7)
 
+.. note::
+
+    This argument will disable tuner, checkpoint callbacks, early stopping callbacks,
+    loggers and logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch.
+
 ----------------
 
 Inspect gradient norms
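As a quick reference, the documented behaviour could be exercised roughly as follows (a minimal sketch, not part of the diff; the callback arguments are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers.base import DummyLogger

    # debugging run: 1 train/val/test batch each, 1 epoch, then the program ends
    trainer = Trainer(
        fast_dev_run=True,
        callbacks=[ModelCheckpoint(monitor='val_loss'), EarlyStopping(monitor='val_loss')],
    )

    # based on the change above, the logger should already be replaced with a no-op
    assert isinstance(trainer.logger, DummyLogger)

    # trainer.fit(model)  # `model` would be any LightningModule; no checkpoints are written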

docs/source/trainer.rst

Lines changed: 3 additions & 3 deletions

@@ -666,9 +666,9 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a
 .. note::
 
     This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
-    disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be
-    used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't
-    disable anything.
+    disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like
+    ``LearningRateLogger`` and runs for only 1 epoch. This must be used only for debugging purposes.
+    ``limit_train/val/test_batches`` only limits the number of batches and won't disable anything.
 
 flush_logs_every_n_steps
 ^^^^^^^^^^^^^^^^^^^^^^^^
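To make the contrast in the note above concrete, a hedged side-by-side sketch (the values are arbitrary):

    from pytorch_lightning import Trainer

    # fast_dev_run: runs 1 batch of train/val/test for 1 epoch and disables
    # tuner, checkpointing, early stopping and loggers -- debugging only
    debug_trainer = Trainer(fast_dev_run=True)

    # limit_*_batches: only caps how many batches run; checkpointing,
    # early stopping and logging keep working as usual
    partial_trainer = Trainer(limit_train_batches=0.1, limit_val_batches=5)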

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 9 additions & 14 deletions

@@ -28,7 +28,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
 
 
 class EarlyStopping(Callback):
@@ -166,10 +166,10 @@ def on_validation_end(self, trainer, pl_module):
         self._run_early_stopping_check(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
-        if trainer.running_sanity_check:
+        if trainer.fast_dev_run or trainer.running_sanity_check:
             return
 
-        if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
+        if self._validate_condition_metric(trainer.callback_metrics):
             # turn off early stopping in on_train_epoch_end
             self.based_on_eval_results = True
 
@@ -178,24 +178,19 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         if self.based_on_eval_results:
             return
 
-        # early stopping can also work in the train loop when there is no val loop
-        should_check_early_stop = False
-
-        # fallback to monitor key in result dict
-        if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
-            should_check_early_stop = True
-
-        if should_check_early_stop:
-            self._run_early_stopping_check(trainer, pl_module)
+        self._run_early_stopping_check(trainer, pl_module)
 
     def _run_early_stopping_check(self, trainer, pl_module):
         """
         Checks whether the early stopping condition is met
        and if so tells the trainer to stop the training.
         """
-        logs = trainer.logger_connector.callback_metrics
+        logs = trainer.callback_metrics
 
-        if not self._validate_condition_metric(logs):
+        if (
+            trainer.fast_dev_run  # disable early_stopping with fast_dev_run
+            or not self._validate_condition_metric(logs)  # short circuit if metric not present
+        ):
             return  # short circuit if metric not present
 
         current = logs.get(self.monitor)
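The net effect of the changes above is a single early-exit guard; a simplified standalone sketch of that pattern (names shortened, not the actual class method):

    def run_early_stopping_check(trainer, monitor, validate_condition_metric):
        """Illustrative only: mirrors the short-circuit added in this commit."""
        logs = trainer.callback_metrics
        if (
            trainer.fast_dev_run  # early stopping is disabled under fast_dev_run
            or not validate_condition_metric(logs)  # short circuit if metric not present
        ):
            return
        current = logs.get(monitor)
        # ... compare `current` with the best value so far and maybe stop training ...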

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions

@@ -24,7 +24,7 @@
 import shutil
 import subprocess
 import time
-from typing import List, Tuple, Dict
+from typing import Dict, List, Tuple
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_only
@@ -213,5 +213,4 @@ def _should_log(trainer) -> bool:
             or trainer.should_stop
         )
 
-        should_log = should_log and not trainer.fast_dev_run
         return should_log

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 2 additions & 3 deletions

@@ -105,15 +105,15 @@ def on_train_batch_start(self, trainer, *args, **kwargs):
             interval = 'step' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)
 
-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_train_epoch_start(self, trainer, *args, **kwargs):
         if self.logging_interval != 'step':
             interval = 'epoch' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)
 
-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
@@ -190,5 +190,4 @@ def _should_log(trainer) -> bool:
             or trainer.should_stop
         )
 
-        should_log = should_log and not trainer.fast_dev_run
         return should_log

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 5 deletions

@@ -20,11 +20,11 @@
 
 """
 
-from copy import deepcopy
 import numbers
 import os
-from pathlib import Path
 import re
+from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
 import numpy as np
@@ -224,7 +224,8 @@ def save_checkpoint(self, trainer, pl_module):
         global_step = trainer.global_step
 
         if (
-            self.save_top_k == 0  # no models are saved
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
@@ -478,14 +479,14 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
             version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))
 
             ckpt_path = os.path.join(
-                save_dir, name, version, "checkpoints"
+                save_dir, str(name), version, "checkpoints"
             )
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
 
         self.dirpath = ckpt_path
 
-        if trainer.is_global_zero:
+        if not trainer.fast_dev_run and trainer.is_global_zero:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer):
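Condensed, the early-return condition in save_checkpoint now reads as below (an illustrative helper, not the actual method; the argument names are hypothetical):

    def should_skip_saving(trainer, save_top_k, period, epoch):
        """Illustrative only: when ModelCheckpoint.save_checkpoint returns early."""
        return bool(
            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
            or save_top_k == 0  # no models are saved
            or period < 1  # no models are saved
            or (epoch + 1) % period  # truthy whenever this is not a save epoch
            or trainer.running_sanity_check  # don't save anything during sanity check
        )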

pytorch_lightning/callbacks/progress.py

Lines changed: 1 addition & 2 deletions

@@ -22,7 +22,6 @@
 import importlib
 import sys
 
-
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
 if importlib.util.find_spec('ipywidgets') is not None:
@@ -323,7 +322,7 @@ def on_epoch_start(self, trainer, pl_module):
         super().on_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
-        if total_train_batches != float('inf') and not trainer.fast_dev_run:
+        if total_train_batches != float('inf'):
             # val can be checked multiple times per epoch
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
             total_val_batches = total_val_batches * val_checks_per_epoch

pytorch_lightning/trainer/connectors/debugging_connector.py

Lines changed: 10 additions & 3 deletions

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from typing import Union
-from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
+
+from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class DebuggingConnector:
@@ -54,11 +56,16 @@ def on_init_start(
             limit_train_batches = fast_dev_run
             limit_val_batches = fast_dev_run
             limit_test_batches = fast_dev_run
+            self.trainer.max_steps = fast_dev_run
             self.trainer.num_sanity_val_steps = 0
             self.trainer.max_epochs = 1
+            self.trainer.val_check_interval = 1.0
+            self.trainer.check_val_every_n_epoch = 1
+            self.trainer.logger = DummyLogger()
+
             rank_zero_info(
                 'Running in fast_dev_run mode: will run a full train,'
-                f' val and test loop using {fast_dev_run} batch(es)'
+                f' val and test loop using {fast_dev_run} batch(es).'
             )
 
         self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
pytorch_lightning/trainer/properties.py

Lines changed: 27 additions & 7 deletions

@@ -15,10 +15,10 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
-from typing import List, Optional, Type, TypeVar, Union, cast
+from typing import cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -27,7 +27,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn
+from pytorch_lightning.utilities import argparse_utils, HOROVOD_AVAILABLE, rank_zero_warn, TPU_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_utils import is_overridden
 
@@ -196,7 +196,7 @@ def enable_validation(self) -> bool:
         """ Check if we should run validation during training. """
         model_ref = self.model_connector.get_model()
         val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
-        return val_loop_enabled or self.fast_dev_run
+        return val_loop_enabled
 
     @property
     def default_root_dir(self) -> str:
@@ -218,18 +218,38 @@ def weights_save_path(self) -> str:
             return os.path.normpath(self._weights_save_path)
         return self._weights_save_path
 
+    @property
+    def early_stopping_callback(self) -> Optional[EarlyStopping]:
+        """
+        The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
+        """
+        callbacks = self.early_stopping_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def early_stopping_callbacks(self) -> List[EarlyStopping]:
+        """
+        A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+        found in the Trainer.callbacks list.
+        """
+        return [c for c in self.callbacks if isinstance(c, EarlyStopping)]
+
     @property
     def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
         """
-        The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
-        no checkpoint callbacks exist.
+        The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
         """
         callbacks = self.checkpoint_callbacks
         return callbacks[0] if len(callbacks) > 0 else None
 
     @property
     def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
-        """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
+        """
+        A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+        found in the Trainer.callbacks list.
+        """
         return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
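The new early-stopping properties mirror the existing checkpoint ones; a hedged usage example (the monitor key is arbitrary):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    trainer = Trainer(callbacks=[EarlyStopping(monitor='val_loss')])

    # first EarlyStopping in Trainer.callbacks, or None if there is none
    es = trainer.early_stopping_callback

    # every EarlyStopping instance in Trainer.callbacks
    all_es = trainer.early_stopping_callbacks
    assert es is all_es[0]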
