From 016702f2ed93c5dfac9caa69db7c4e5468ac3a48 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 26 Dec 2020 18:54:45 +0530 Subject: [PATCH 1/8] Disable checkpointing, earlystopping and logger with fast_dev_run --- .../callbacks/gpu_stats_monitor.py | 3 +- pytorch_lightning/callbacks/lr_monitor.py | 5 +- pytorch_lightning/callbacks/progress.py | 3 +- .../trainer/connectors/debugging_connector.py | 15 +++- pytorch_lightning/trainer/properties.py | 26 +++++-- pytorch_lightning/trainer/training_loop.py | 5 +- tests/callbacks/test_early_stopping.py | 12 +-- .../test_checkpoint_callback_frequency.py | 47 +---------- tests/checkpointing/test_model_checkpoint.py | 16 ++-- tests/loggers/test_all.py | 8 +- tests/trainer/flags/test_fast_dev_run.py | 77 ++++++++++++++++++- 11 files changed, 133 insertions(+), 84 deletions(-) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index b083511392bb3..1403d0bdf2e31 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -24,7 +24,7 @@ import shutil import subprocess import time -from typing import List, Tuple, Dict +from typing import Dict, List, Tuple from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_only @@ -213,5 +213,4 @@ def _should_log(trainer) -> bool: or trainer.should_stop ) - should_log = should_log and not trainer.fast_dev_run return should_log diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 9799e0d3298d3..712695d69ecec 100755 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -105,7 +105,7 @@ def on_train_batch_start(self, trainer, *args, **kwargs): interval = 'step' if self.logging_interval is None else 'any' latest_stat = self._extract_stats(trainer, interval) - if trainer.logger is not None and latest_stat: + if latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) def on_train_epoch_start(self, trainer, *args, **kwargs): @@ -113,7 +113,7 @@ def on_train_epoch_start(self, trainer, *args, **kwargs): interval = 'epoch' if self.logging_interval is None else 'any' latest_stat = self._extract_stats(trainer, interval) - if trainer.logger is not None and latest_stat: + if latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: @@ -190,5 +190,4 @@ def _should_log(trainer) -> bool: or trainer.should_stop ) - should_log = should_log and not trainer.fast_dev_run return should_log diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 6582f16fd27be..3ed5c11fd75d7 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -22,7 +22,6 @@ import importlib import sys - # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed if importlib.util.find_spec('ipywidgets') is not None: @@ -323,7 +322,7 @@ def on_epoch_start(self, trainer, pl_module): super().on_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches - if total_train_batches != float('inf') and not trainer.fast_dev_run: + if total_train_batches != float('inf'): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch 
total_val_batches = total_val_batches * val_checks_per_epoch diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 61d7cbd189fde..f08d231363203 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.exceptions import MisconfigurationException from typing import Union -from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class DebuggingConnector: @@ -54,11 +56,18 @@ def on_init_start( limit_train_batches = fast_dev_run limit_val_batches = fast_dev_run limit_test_batches = fast_dev_run + self.trainer.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 self.trainer.max_epochs = 1 + self.trainer.val_check_interval = 1.0 + self.trainer.check_val_every_n_epoch = 1 + self.trainer.logger = None + self.trainer.callbacks = [ + c for c in self.trainer.callbacks if not isinstance(c, (EarlyStopping, ModelCheckpoint)) + ] rank_zero_info( 'Running in fast_dev_run mode: will run a full train,' - f' val and test loop using {fast_dev_run} batch(es)' + f' val and test loop using {fast_dev_run} batch(es).' ) self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 355bbad3a037e..64c736f14d90a 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import List, Optional, Type, TypeVar, Union, cast +import inspect +import os +from typing import cast, List, Optional, Type, TypeVar, Union from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.loggers.base import LightningLoggerBase @@ -27,7 +27,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn +from pytorch_lightning.utilities import argparse_utils, HOROVOD_AVAILABLE, rank_zero_warn, TPU_AVAILABLE from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.model_utils import is_overridden @@ -196,7 +196,7 @@ def enable_validation(self) -> bool: """ Check if we should run validation during training. 
""" model_ref = self.model_connector.get_model() val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 - return val_loop_enabled or self.fast_dev_run + return val_loop_enabled @property def default_root_dir(self) -> str: @@ -218,6 +218,20 @@ def weights_save_path(self) -> str: return os.path.normpath(self._weights_save_path) return self._weights_save_path + @property + def early_stopping_callback(self) -> Optional[ModelCheckpoint]: + """ + The first early_stopping callback in the Trainer.callbacks list, or ``None`` if + no early_stopping callbacks exist. + """ + callbacks = self.early_stopping_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def early_stopping_callbacks(self) -> List[ModelCheckpoint]: + """ A list of all instances of EarlyStopping found in the Trainer.callbacks list. """ + return [c for c in self.callbacks if isinstance(c, EarlyStopping)] + @property def checkpoint_callback(self) -> Optional[ModelCheckpoint]: """ diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fe4525006ebb9..0271afe3c2d91 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -915,9 +915,8 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs - if should_flush_logs or self.trainer.fast_dev_run is True: - if self.trainer.is_global_zero and self.trainer.logger is not None: - self.trainer.logger.save() + if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: + self.trainer.logger.save() def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 7cecefad03276..5c54f6a84805d 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -13,18 +13,17 @@ # limitations under the License. 
import os import pickle +from unittest import mock import cloudpickle import numpy as np import pytest import torch -from unittest import mock -from pytorch_lightning import _logger -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import _logger, seed_everything, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from tests.base import EvalModelTemplate, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import BoringModel, EvalModelTemplate class EarlyStoppingTestRestore(EarlyStopping): @@ -87,15 +86,18 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): def test_early_stopping_no_extraneous_invocations(tmpdir): """Test to ensure that callback methods aren't being invoked outside of the callback handler.""" model = EvalModelTemplate() + early_stop_callback = EarlyStopping() expected_count = 4 trainer = Trainer( default_root_dir=tmpdir, - callbacks=[EarlyStopping()], + callbacks=[early_stop_callback], val_check_interval=1.0, max_epochs=expected_count, ) trainer.fit(model) + assert trainer.early_stopping_callback == early_stop_callback + assert trainer.early_stopping_callbacks == [early_stop_callback] assert len(trainer.dev_debugger.early_stopping_history) == expected_count diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 857877f8239ba..f9686dce159dd 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -17,55 +17,10 @@ import pytest import torch -from pytorch_lightning import Trainer, callbacks, seed_everything +from pytorch_lightning import callbacks, seed_everything, Trainer from tests.base import BoringModel -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_mc_called_on_fastdevrun(tmpdir): - seed_everything(1234) - - train_val_step_model = BoringModel() - - # fast dev run = called once - # train loop only, dict, eval result - trainer = Trainer(fast_dev_run=True) - trainer.fit(train_val_step_model) - - # checkpoint should have been called once with fast dev run - assert len(trainer.dev_debugger.checkpoint_callback_history) == 1 - - # ----------------------- - # also called once with no val step - # ----------------------- - class TrainingStepCalled(BoringModel): - def __init__(self): - super().__init__() - self.training_step_called = False - self.validation_step_called = False - self.test_step_called = False - - def training_step(self, batch, batch_idx): - self.training_step_called = True - return super().training_step(batch, batch_idx) - - train_step_only_model = TrainingStepCalled() - train_step_only_model.validation_step = None - - # fast dev run = called once - # train loop only, dict, eval result - trainer = Trainer(fast_dev_run=True) - trainer.fit(train_step_only_model) - - # make sure only training step was called - assert train_step_only_model.training_step_called - assert not train_step_only_model.validation_step_called - assert not train_step_only_model.test_step_called - - # checkpoint should have been called once with fast dev run - assert len(trainer.dev_debugger.checkpoint_callback_history) == 1 - - @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_mc_called(tmpdir): seed_everything(1234) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 106c34030051e..cb1d6c2575cf2 100644 --- 
a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,29 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from argparse import Namespace import os +from pathlib import Path import pickle import platform import re -from argparse import Namespace -from pathlib import Path from unittest import mock from unittest.mock import Mock import cloudpickle +from omegaconf import Container, OmegaConf import pytest import torch import yaml -from omegaconf import Container, OmegaConf import pytorch_lightning as pl -import tests.base.develop_utils as tutils -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel +import tests.base.develop_utils as tutils class LogInTwoMethods(BoringModel): @@ -896,7 +896,8 @@ def training_step(self, *args): ) trainer = Trainer( default_root_dir=tmpdir, - fast_dev_run=True, + limit_train_batches=1, + limit_val_batches=1, callbacks=[model_checkpoint], logger=False, weights_summary=None, @@ -922,7 +923,8 @@ def __init__(self, hparams): ) trainer = Trainer( default_root_dir=tmpdir, - fast_dev_run=True, + limit_train_batches=1, + limit_val_batches=1, callbacks=[model_checkpoint], logger=False, weights_summary=None, diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 89c731d432ee9..8fe407efebe54 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -20,7 +20,6 @@ import pytest -import tests.base.develop_utils as tutils from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import ( CometLogger, @@ -32,6 +31,7 @@ ) from pytorch_lightning.loggers.base import DummyExperiment from tests.base import BoringModel, EvalModelTemplate +import tests.base.develop_utils as tutils from tests.loggers.test_comet import _patch_comet_atexit from tests.loggers.test_mlflow import mock_mlflow_run_creation @@ -114,9 +114,9 @@ def log_metrics(self, metrics, step): trainer = Trainer( max_epochs=1, logger=logger, - limit_train_batches=0.2, - limit_val_batches=0.5, - fast_dev_run=True, + limit_train_batches=1, + limit_val_batches=1, + log_every_n_steps=1, default_root_dir=tmpdir, ) trainer.fit(model) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 00c62cdf48fce..2947f6da6f49c 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -1,14 +1,18 @@ +import os +from unittest import mock + import pytest + from pytorch_lightning import Trainer -from tests.base import EvalModelTemplate +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from tests.base import BoringModel @pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder']) def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): """ Test that tuner algorithms are skipped if fast dev run is enabled """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, @@ -19,3 +23,70 @@ def 
test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' with pytest.warns(UserWarning, match=expected_message): trainer.tune(model) + + +@pytest.mark.parametrize('fast_dev_run', [1, 4]) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_mc_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): + """ + Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run + """ + class FastDevRunModel(BoringModel): + def __init__(self): + super().__init__() + self.training_step_called = False + self.validation_step_called = False + self.test_step_called = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + return super().validation_step(batch, batch_idx) + + def _make_fast_dev_run_assertions(trainer): + # there should be no logger with fast_dev_run + assert trainer.logger is None + assert len(trainer.dev_debugger.logged_metrics) == 0 + + # checkpoint and early stopping should not have been called with fast_dev_run + assert trainer.early_stopping_callback is None + assert trainer.checkpoint_callback is None + assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + assert len(trainer.dev_debugger.early_stopping_history) == 0 + + train_val_step_model = FastDevRunModel() + trainer_config = dict( + fast_dev_run=fast_dev_run, + logger=True, + log_every_n_steps=1, + callbacks=[ModelCheckpoint(), EarlyStopping()], + ) + + trainer = Trainer(**trainer_config) + results = trainer.fit(train_val_step_model) + assert results + + # make sure both training_step and validation_step were called + assert train_val_step_model.training_step_called + assert train_val_step_model.validation_step_called + + _make_fast_dev_run_assertions(trainer) + + # ----------------------- + # also called once with no val step + # ----------------------- + train_step_only_model = FastDevRunModel() + train_step_only_model.validation_step = None + + trainer = Trainer(**trainer_config) + results = trainer.fit(train_step_only_model) + assert results + + # make sure only training_step was called + assert train_step_only_model.training_step_called + assert not train_step_only_model.validation_step_called + + _make_fast_dev_run_assertions(trainer) From 544f38e3bd7dc5fbd365b6dedd787713fbea48a9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 26 Dec 2020 19:04:14 +0530 Subject: [PATCH 2/8] docs --- docs/source/debugging.rst | 7 ++++++- docs/source/trainer.rst | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 5eaf4303d3e4c..f3faa72f1e95e 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -28,13 +28,18 @@ The point is to detect any bugs in the training/validation loop without having t argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) .. testcode:: - + # runs 1 train, val, test batch and program ends trainer = Trainer(fast_dev_run=True) # runs 7 train, val, test batches and program ends trainer = Trainer(fast_dev_run=7) +.. note:: + + This argument will disable tuner, checkpoint callbacks, early stopping callbacks, + loggers and logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. 
+ ---------------- Inspect gradient norms diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 634a0c5d3d9dc..9efdd0bf09740 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -666,9 +666,9 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a .. note:: This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will - disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be - used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't - disable anything. + disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like + ``LearningRateLogger`` and runs for only 1 epoch. This must be used only for debugging purposes. + ``limit_train/val/test_batches`` only limits the number of batches and won't disable anything. gpus ^^^^ From d06a7a09113b300d5f869ace28841170a14d7401 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 27 Dec 2020 16:24:29 +0530 Subject: [PATCH 3/8] chlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b9b705459510..9397535712002 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) From a647ad5d4a2ab446e076daf13f80bbb343df5e11 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 28 Dec 2020 23:51:09 +0530 Subject: [PATCH 4/8] disable callbacks and enable DummyLogger --- pytorch_lightning/callbacks/early_stopping.py | 23 ++++++-------- .../callbacks/model_checkpoint.py | 7 +++-- .../trainer/connectors/debugging_connector.py | 8 ++--- pytorch_lightning/trainer/properties.py | 22 +++++++++----- tests/trainer/flags/test_fast_dev_run.py | 30 +++++++++++-------- 5 files changed, 48 insertions(+), 42 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4125a924cb2c5..3e15d8462350c 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -28,7 +28,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE class EarlyStopping(Callback): @@ -166,10 +166,10 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer, pl_module) def on_validation_epoch_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.fast_dev_run or trainer.running_sanity_check: return - if self._validate_condition_metric(trainer.logger_connector.callback_metrics): + if self._validate_condition_metric(trainer.callback_metrics): # turn off early stopping in on_train_epoch_end self.based_on_eval_results = True @@ -178,24 +178,19 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): if self.based_on_eval_results: return - # early stopping can also work in the train loop when there is no val loop - should_check_early_stop = False - - # fallback to monitor key in result dict - if trainer.logger_connector.callback_metrics.get(self.monitor, None) is 
not None: - should_check_early_stop = True - - if should_check_early_stop: - self._run_early_stopping_check(trainer, pl_module) + self._run_early_stopping_check(trainer, pl_module) def _run_early_stopping_check(self, trainer, pl_module): """ Checks whether the early stopping condition is met and if so tells the trainer to stop the training. """ - logs = trainer.logger_connector.callback_metrics + logs = trainer.callback_metrics - if not self._validate_condition_metric(logs): + if ( + trainer.fast_dev_run # disable early_stopping with fast_dev_run + or not self._validate_condition_metric(logs) # short circuit if metric not present + ): return # short circuit if metric not present current = logs.get(self.monitor) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5a1079f8063f4..36d8f0278f34c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -224,7 +224,8 @@ def save_checkpoint(self, trainer, pl_module): global_step = trainer.global_step if ( - self.save_top_k == 0 # no models are saved + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # don't save anything during sanity check @@ -478,14 +479,14 @@ def __resolve_ckpt_dir(self, trainer, pl_module): version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name)) ckpt_path = os.path.join( - save_dir, name, version, "checkpoints" + save_dir, str(name), version, "checkpoints" ) else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") self.dirpath = ckpt_path - if trainer.is_global_zero: + if not trainer.fast_dev_run and trainer.is_global_zero: self._fs.makedirs(self.dirpath, exist_ok=True) def _add_backward_monitor_support(self, trainer): diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index f08d231363203..ecba35d5dbf55 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -14,7 +14,7 @@ from typing import Union -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -61,10 +61,8 @@ def on_init_start( self.trainer.max_epochs = 1 self.trainer.val_check_interval = 1.0 self.trainer.check_val_every_n_epoch = 1 - self.trainer.logger = None - self.trainer.callbacks = [ - c for c in self.trainer.callbacks if not isinstance(c, (EarlyStopping, ModelCheckpoint)) - ] + self.trainer.logger = DummyLogger() + rank_zero_info( 'Running in fast_dev_run mode: will run a full train,' f' val and test loop using {fast_dev_run} batch(es).' 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 64c736f14d90a..5317593f05fff 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -219,31 +219,37 @@ def weights_save_path(self) -> str: return self._weights_save_path @property - def early_stopping_callback(self) -> Optional[ModelCheckpoint]: + def early_stopping_callback(self) -> Optional[EarlyStopping]: """ - The first early_stopping callback in the Trainer.callbacks list, or ``None`` if - no early_stopping callbacks exist. + The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` + callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. """ callbacks = self.early_stopping_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def early_stopping_callbacks(self) -> List[ModelCheckpoint]: - """ A list of all instances of EarlyStopping found in the Trainer.callbacks list. """ + def early_stopping_callbacks(self) -> List[EarlyStopping]: + """ + A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` + found in the Trainer.callbacks list. + """ return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property def checkpoint_callback(self) -> Optional[ModelCheckpoint]: """ - The first checkpoint callback in the Trainer.callbacks list, or ``None`` if - no checkpoint callbacks exist. + The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. """ callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property def checkpoint_callbacks(self) -> List[ModelCheckpoint]: - """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """ + """ + A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + found in the Trainer.callbacks list. 
+ """ return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] def save_checkpoint(self, filepath, weights_only: bool = False): diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 2947f6da6f49c..8464efd468aaf 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -5,6 +5,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers.base import DummyLogger from tests.base import BoringModel @@ -27,7 +28,7 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): @pytest.mark.parametrize('fast_dev_run', [1, 4]) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_mc_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): +def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): """ Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run """ @@ -46,25 +47,30 @@ def validation_step(self, batch, batch_idx): self.validation_step_called = True return super().validation_step(batch, batch_idx) + checkpoint_callback = ModelCheckpoint() + early_stopping_callback = EarlyStopping() + trainer_config = dict( + fast_dev_run=fast_dev_run, + logger=True, + log_every_n_steps=1, + callbacks=[checkpoint_callback, early_stopping_callback], + ) + def _make_fast_dev_run_assertions(trainer): # there should be no logger with fast_dev_run - assert trainer.logger is None + assert isinstance(trainer.logger, DummyLogger) assert len(trainer.dev_debugger.logged_metrics) == 0 - # checkpoint and early stopping should not have been called with fast_dev_run - assert trainer.early_stopping_callback is None - assert trainer.checkpoint_callback is None + # checkpoint callback should not have been called with fast_dev_run + assert trainer.checkpoint_callback == checkpoint_callback + assert not os.path.exists(checkpoint_callback.dirpath) assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + + # early stopping should not have been called with fast_dev_run + assert trainer.early_stopping_callback == early_stopping_callback assert len(trainer.dev_debugger.early_stopping_history) == 0 train_val_step_model = FastDevRunModel() - trainer_config = dict( - fast_dev_run=fast_dev_run, - logger=True, - log_every_n_steps=1, - callbacks=[ModelCheckpoint(), EarlyStopping()], - ) - trainer = Trainer(**trainer_config) results = trainer.fit(train_val_step_model) assert results From 26a51e4d1b44769620b178ec345b2e89b45cbd27 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 3 Jan 2021 21:32:56 +0530 Subject: [PATCH 5/8] add log --- tests/trainer/flags/test_fast_dev_run.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 8464efd468aaf..3bfd549243f99 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -2,6 +2,7 @@ from unittest import mock import pytest +import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint @@ -40,6 +41,8 @@ def __init__(self): self.test_step_called = False def training_step(self, batch, batch_idx): + self.log('some_metric', torch.tensor(7.)) + self.logger.experiment.add_scaler('some_distribution', torch.randn(7) + batch_idx) self.training_step_called = True return super().training_step(batch, batch_idx) @@ -59,7 +62,7 @@ def 
validation_step(self, batch, batch_idx): def _make_fast_dev_run_assertions(trainer): # there should be no logger with fast_dev_run assert isinstance(trainer.logger, DummyLogger) - assert len(trainer.dev_debugger.logged_metrics) == 0 + assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run # checkpoint callback should not have been called with fast_dev_run assert trainer.checkpoint_callback == checkpoint_callback From 653af657a368cac1175f19b4730d2d2499c8d0b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 3 Jan 2021 18:20:36 +0100 Subject: [PATCH 6/8] use dummy logger method --- tests/trainer/flags/test_fast_dev_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 3bfd549243f99..624b3cc6ac9c2 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -42,7 +42,7 @@ def __init__(self): def training_step(self, batch, batch_idx): self.log('some_metric', torch.tensor(7.)) - self.logger.experiment.add_scaler('some_distribution', torch.randn(7) + batch_idx) + self.logger.experiment.dummy_log('some_distribution', torch.randn(7) + batch_idx) self.training_step_called = True return super().training_step(batch, batch_idx) From 3c65ae5dc9130af9a7e4181ff9849e8c6424d027 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 4 Jan 2021 20:44:23 +0100 Subject: [PATCH 7/8] Apply suggestions from code review --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9397535712002..c4afeff4ff01d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed + - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) From 51635cbee4bb1196f426249e17bef7eb091f1def Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 5 Jan 2021 01:33:18 +0530 Subject: [PATCH 8/8] isort --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/trainer/properties.py | 4 ++-- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- tests/loggers/test_all.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 36d8f0278f34c..a578c1d697f8e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,11 +20,11 @@ """ -from copy import deepcopy import numbers import os -from pathlib import Path import re +from copy import deepcopy +from pathlib import Path from typing import Any, Dict, Optional, Union import numpy as np diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 5317593f05fff..614c863fa7256 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from abc import ABC -from argparse import ArgumentParser, Namespace import inspect import os +from abc import ABC +from argparse import ArgumentParser, Namespace from typing import cast, List, Optional, Type, TypeVar, Union from pytorch_lightning.accelerators.accelerator import Accelerator diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index cb1d6c2575cf2..99ed807f111e5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,29 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from argparse import Namespace import os -from pathlib import Path import pickle import platform import re +from argparse import Namespace +from pathlib import Path from unittest import mock from unittest.mock import Mock import cloudpickle -from omegaconf import Container, OmegaConf import pytest import torch import yaml +from omegaconf import Container, OmegaConf import pytorch_lightning as pl +import tests.base.develop_utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel -import tests.base.develop_utils as tutils class LogInTwoMethods(BoringModel): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 8fe407efebe54..795b1a91e688e 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -20,6 +20,7 @@ import pytest +import tests.base.develop_utils as tutils from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import ( CometLogger, @@ -31,7 +32,6 @@ ) from pytorch_lightning.loggers.base import DummyExperiment from tests.base import BoringModel, EvalModelTemplate -import tests.base.develop_utils as tutils from tests.loggers.test_comet import _patch_comet_atexit from tests.loggers.test_mlflow import mock_mlflow_run_creation
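
A minimal usage sketch of the behaviour this patch series introduces, drawn from the new test in tests/trainer/flags/test_fast_dev_run.py. BoringModel is the repository's test helper; any LightningModule works in its place. With fast_dev_run enabled, the Trainer swaps the logger for a DummyLogger, keeps ModelCheckpoint and EarlyStopping in the callback list but makes them return early, and runs a single epoch of fast_dev_run batches.

    import os

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers.base import DummyLogger
    from tests.base import BoringModel  # repo test helper; any LightningModule works here

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint()

    trainer = Trainer(
        fast_dev_run=True,  # run 1 train/val/test batch, then stop
        logger=True,  # replaced by DummyLogger while fast_dev_run is active
        callbacks=[checkpoint_callback, EarlyStopping()],  # kept in the list, but disabled
    )
    trainer.fit(model)

    # the logger is stubbed out, so no metrics or experiment files are written
    assert isinstance(trainer.logger, DummyLogger)

    # the checkpoint callback is still discoverable on the trainer,
    # but save_checkpoint() returns early, so its directory is never created
    assert trainer.checkpoint_callback == checkpoint_callback
    assert not os.path.exists(checkpoint_callback.dirpath)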