From 104d796f6afcf28e7b2e0163297e91cac28b229b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 02:32:44 +0100 Subject: [PATCH 01/34] Update code Co-authored-by: EliaCereda --- pytorch_lightning/accelerators/accelerator.py | 4 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- .../callbacks/model_checkpoint.py | 1 + .../trainer/configuration_validator.py | 22 ++--- .../trainer/connectors/data_connector.py | 6 +- .../logger_connector/logger_connector.py | 6 +- pytorch_lightning/trainer/evaluation_loop.py | 8 +- pytorch_lightning/trainer/states.py | 5 +- pytorch_lightning/trainer/trainer.py | 85 ++++++++++--------- tests/trainer/test_trainer.py | 6 +- 10 files changed, 76 insertions(+), 69 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 84d53b5addd6b..3e185cabc84a2 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -311,7 +311,7 @@ def setup_optimizers(self, trainer): trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if trainer.testing: + if trainer.evaluating: return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module @@ -409,7 +409,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I @property def results(self) -> Any: """ - The results of the last training/testing run will be cached within the training type plugin. + The results of the last run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 384ce9699f60e..d3a1adf2b41d2 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -155,7 +155,7 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return self._run_early_stopping_check(trainer, pl_module) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 461c211baab12..bc53a9f19dc1f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -235,6 +235,7 @@ def save_checkpoint(self, trainer, pl_module): if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.evaluating # disable checkpointing during validation and test or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9cb22f39b7228..220066bca1cfc 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -30,12 +30,11 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. 
""" - if not self.trainer.testing: + if not self.trainer.evaluating: self.__verify_train_loop_configuration(model) - self.__verify_eval_loop_configuration(model, 'validation') else: - # check test loop configuration - self.__verify_eval_loop_configuration(model, 'test') + # check evaluation loop configuration + self.__verify_eval_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -83,18 +82,15 @@ def __verify_train_loop_configuration(self, model): ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) - def __verify_eval_loop_configuration(self, model, eval_loop_name): - step_name = f'{eval_loop_name}_step' - - # map the dataloader name - loader_name = f'{eval_loop_name}_dataloader' - if eval_loop_name == 'validation': - loader_name = 'val_dataloader' + def __verify_eval_loop_configuration(self, model): + stage = self.trainer._running_stage + step_name = f'{stage}_step' + loader_name = 'val_dataloader' if self.trainer.validating else f'{stage}_dataloader' has_loader = is_overridden(loader_name, model) has_step = is_overridden(step_name, model) if has_loader and not has_step: - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop') + rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: - rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop') + rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 6ff35aadc36a3..76d610a19df27 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -112,10 +112,8 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None: - # Todo: required argument `stage` is not used - - # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None: + # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) # If we have a datamodule, attach necessary hooks + dataloaders diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8ebec3238e276..4aa1b45a5981c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -297,11 +297,11 @@ def get_evaluate_epoch_results(self): self.prepare_eval_loop_results() - # log results of test - if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test: + # log results of evaluation + if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') + print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS') pprint({ k: (v.item() if v.numel() == 1 else 
v.tolist()) if isinstance(v, torch.Tensor) else v for k, v in results.items() diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 087741aa69c2b..6c460f4081a31 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -39,11 +39,11 @@ def on_trainer_init(self): self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False - # when .test() is called, it sets this - self.trainer.tested_ckpt_path = None + # .validate() and .test() set this when they load a checkpoint + self.trainer.evaluated_ckpt_path = None - # when true, prints test results - self.trainer.verbose_test = True + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self, max_batches): # select dataloaders diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 1758cb41ee780..8194fd990cad5 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -44,11 +44,14 @@ class RunningStage(LightningEnum): True """ TRAINING = 'train' - EVALUATING = 'eval' + VALIDATING = 'validation' TESTING = 'test' PREDICTING = 'predict' TUNING = 'tune' + def is_evaluating(self) -> bool: + return self in (self.VALIDATING, self.TESTING) + def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf3bfd7a3e5a3..8744fa9b59f2f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -872,21 +872,20 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" - - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. Args: ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. - + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. model: The model to test. - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. - verbose: If True, prints the test results + verbose: If True, prints the test results. Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. 
@@ -894,31 +893,32 @@ def test( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose + self.verbose_evaluate = verbose self._running_stage = RunningStage.TESTING - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' ) - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.lightning_module, datamodule, 'test') + model_provided = model is not None + model = model or self.lightning_module - if model is not None: - results = self.__test_given_model(model, test_dataloaders) - else: - results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule, 'test') + results = ( + self.__evaluate_given_model(model, test_dataloaders + if model_provided else + self.__evaluate_using_best_weights(model, ckpt_path, test_dataloaders) + ) - self.teardown('test') + self.teardown('test', model=model) self._running_stage = None return results - def __test_using_best_weights(self, ckpt_path, test_dataloaders): - model = self.lightning_module - + def __evaluate_using_best_weights(self, model, ckpt_path: Optional[str] = None, dataloaders: Union[DataLoader, List[DataLoader]]): # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( @@ -944,33 +944,32 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model.load_state_dict(ckpt['state_dict']) # attach dataloaders - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - # run tests + # run test self.tested_ckpt_path = ckpt_path results = self.fit(model) # teardown - if self.is_function_implemented('teardown'): - model_ref = self.lightning_module - model_ref.teardown('test') + if self.is_function_implemented('teardown', model=model): + model.teardown('test') return results - def __test_given_model(self, model, test_dataloaders): + def __evaluate_given_model(self, model, dataloaders: Union[DataLoader, List[DataLoader]]): # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) # run test # sets up testing so we short circuit to eval results = self.fit(model) # teardown - if self.is_function_implemented('teardown'): - model.teardown('test') + if self.is_function_implemented('teardown', model=model): + model.teardown(stage) return results @@ -1052,11 +1051,17 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = 'test' if self.evaluating else 'fit' + if self.datamodule is not None: - called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit + called = { + 'fit': self.datamodule.has_setup_fit, + 'test': self.datamodule.has_setup_test, + 
}[stage_name] + if not called: self.datamodule.setup(stage_name) + self.setup(model, stage_name) model.setup(stage_name) @@ -1151,12 +1156,16 @@ def tuning(self, val: bool) -> None: self._running_stage = None @property - def evaluating(self) -> bool: - return self._running_stage == RunningStage.EVALUATING + def validating(self) -> bool: + return self._running_stage == RunningStage.VALIDATING - @evaluating.setter - def evaluating(self, val: bool) -> None: + @validating.setter + def validating(self, val: bool) -> None: if val: - self._running_stage = RunningStage.EVALUATING - elif self.evaluating: + self._running_stage = RunningStage.VALIDATING + elif self.validating: self._running_stage = None + + @property + def evaluating(self) -> bool: + return self._running_stage and self._running_stage.is_evaluating() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 167930425dab1..03e970c623538 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -726,12 +726,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): trainer.test(ckpt_path=ckpt_path) else: trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path is None + assert trainer.evaluated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -743,7 +743,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): )[0].absolute() ) trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + assert trainer.evaluated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): From 23c2d3ba91dfea16bea78f8a6a6718d1fd285dde Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 02:59:54 +0100 Subject: [PATCH 02/34] More property updates --- .../plugins/training_type/ddp_spawn.py | 8 ++++-- .../plugins/training_type/sharded.py | 2 +- .../plugins/training_type/sharded_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 10 +++++--- .../trainer/connectors/data_connector.py | 2 +- .../logger_connector/epoch_result_store.py | 4 +-- .../logger_connector/logger_connector.py | 6 ++--- .../trainer/connectors/model_connector.py | 1 - pytorch_lightning/trainer/evaluation_loop.py | 5 ++-- pytorch_lightning/trainer/trainer.py | 25 +++++++++++-------- tests/models/test_restore.py | 2 +- 11 files changed, 38 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ca25a6d8bc382..77ed11604dee0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -218,7 +218,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? 
- if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if ( + not self.lightning_module.trainer.evaluating + and best_model_path is not None + and len(best_model_path) > 0 + ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) @@ -235,7 +239,7 @@ def __recover_child_process_weights(self, best_path, last_path): # todo, pass also best score # load last weights - if last_path is not None and not self.lightning_module.trainer.testing: + if last_path is not None and not self.lightning_module.trainer.evaluating: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index ad0ab693bee0d..66f2e4f56a230 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -36,7 +36,7 @@ def _reinit_optimizers_with_oss(self): def _wrap_optimizers(self): trainer = self.model.trainer - if trainer.testing is True: + if trainer.evaluating is True: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index a8d497cd119b0..c7b770316c6fa 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -32,7 +32,7 @@ def _reinit_optimizers_with_oss(self): def _wrap_optimizers(self): trainer = self.model.trainer - if trainer.testing: + if trainer.evaluating: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 692a4426a6ad6..e53bea1a64529 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -124,7 +124,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? 
- if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if ( + not self.lightning_module.trainer.evaluating + and best_model_path is not None + and len(best_model_path) > 0 + ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) xm.save(self.lightning_module.state_dict(), last_path) @@ -214,7 +218,7 @@ def post_dispatch(self) -> None: # todo, pass also bets score # load last weights - if last_path and not self.lightning_module.trainer.testing: + if last_path and not self.lightning_module.trainer.evaluating: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -228,7 +232,7 @@ def __load_weights_on_main_process(self) -> None: # load weights if not interrupted # TODO: check for trainer reference - if on_colab_kaggle() and not model.trainer.testing: + if on_colab_kaggle() and not model.trainer.evaluating: self.load_spawn_weights(model) self._model = model diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 76d610a19df27..9e08cf031175f 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -81,7 +81,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): # set up the passed in dataloaders (if needed) self.attach_dataloaders(model, train_dataloader, val_dataloaders) - self.attach_datamodule(model, datamodule, 'fit') + self.attach_datamodule(model, datamodule) def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): # If you supply a datamodule you can't supply train_dataloader or val_dataloaders diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index c435204107775..e35beb56310a4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -309,7 +309,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics = {} batch_pbar_metrics = {} batch_log_metrics = {} - is_train = self._stage in RunningStage.TRAINING + is_train = self._stage is RunningStage.TRAINING if not self._has_batch_loop_finished: # get pbar @@ -339,7 +339,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics.update(epoch_log_metrics) callback_metrics.update(forked_metrics) - if not is_train and self.trainer.testing: + if not is_train and self.trainer.evaluating: logger_connector.evaluation_callback_metrics.update(callback_metrics) # update callback_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 4aa1b45a5981c..8808edbaf0fff 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -332,7 +332,7 @@ def _track_callback_metrics(self, eval_results): flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.testing: + if self.trainer.evaluating: self.trainer.logger_connector.evaluation_callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks @@ -347,7 
+347,7 @@ def _track_callback_metrics(self, eval_results): flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.testing: + if self.trainer.evaluating: self.trainer.logger_connector.evaluation_callback_metrics.update(flat) def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): @@ -365,7 +365,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric callback_metrics.update(log_metrics) callback_metrics.update(prog_bar_metrics) self.trainer.logger_connector.callback_metrics.update(callback_metrics) - if self.trainer.testing: + if self.trainer.evaluating: self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 4a0c565d78be0..cdaab6248f006 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -35,5 +35,4 @@ def copy_trainer_model_properties(self, model): m._device_type = str(self.trainer._device_type) m._distrib_type = str(self.trainer._distrib_type) m.use_amp = self.trainer.amp_backend is not None - m.testing = self.trainer.testing m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 6c460f4081a31..69155dd224d08 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -32,9 +32,9 @@ def __init__(self, trainer): self.num_dataloaders = None def on_trainer_init(self): - self.trainer.num_val_batches = [] self.trainer.num_sanity_val_batches = [] self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False @@ -46,7 +46,6 @@ def on_trainer_init(self): self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self, max_batches): - # select dataloaders model = self.trainer.lightning_module # select dataloaders @@ -154,7 +153,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._results = Result() - if self.testing: + if self.trainer.testing: model_ref._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator.test_step(args) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8744fa9b59f2f..7edd820bf0192 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -667,9 +667,8 @@ def run_train(self): self.train_loop.on_train_end() def run_evaluation(self, max_batches=None, on_epoch=False): - # used to know if we are logging for val, test + reset cached results - self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING + self._running_stage = RunningStage.TESTING if self.testing else RunningStage.VALIDATING self.logger_connector.reset() # bookkeeping @@ -907,18 +906,23 @@ def test( model = model or self.lightning_module # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule, 'test') + self.data_connector.attach_datamodule(model, datamodule) results = ( - 
self.__evaluate_given_model(model, test_dataloaders + self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else - self.__evaluate_using_best_weights(model, ckpt_path, test_dataloaders) + self.__evaluate_using_best_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) - self.teardown('test', model=model) + self.teardown('test') self._running_stage = None return results - def __evaluate_using_best_weights(self, model, ckpt_path: Optional[str] = None, dataloaders: Union[DataLoader, List[DataLoader]]): + def __evaluate_using_best_weights( + self, + model, + ckpt_path: Optional[str] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None + ): # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( @@ -957,8 +961,7 @@ def __evaluate_using_best_weights(self, model, ckpt_path: Optional[str] = None, return results - def __evaluate_given_model(self, model, dataloaders: Union[DataLoader, List[DataLoader]]): - + def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): # attach data if dataloaders is not None: self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) @@ -969,7 +972,7 @@ def __evaluate_given_model(self, model, dataloaders: Union[DataLoader, List[Data # teardown if self.is_function_implemented('teardown', model=model): - model.teardown(stage) + model.teardown('test') return results @@ -1013,7 +1016,7 @@ def predict( if datamodule is not None: # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule, 'predict') + self.data_connector.attach_datamodule(model, datamodule) # attach data if dataloaders is not None: diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 7d6c104abbd57..11dd94a45964e 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -452,7 +452,7 @@ def on_train_start(self): # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model - new_trainer._running_stage = RunningStage.EVALUATING + new_trainer._running_stage = RunningStage.VALIDATING dataloader = self.train_dataloader() tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) From dd0913434960cda3da683ae3c64c9d15574b95a9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 02:18:17 +0100 Subject: [PATCH 03/34] Move properties. 
Introduce trainer._fitting --- pytorch_lightning/trainer/properties.py | 63 +++++++++++++++++++- pytorch_lightning/trainer/trainer.py | 77 +++---------------------- 2 files changed, 69 insertions(+), 71 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index c061c6ef28d4c..37ab583ce89b1 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -31,7 +31,7 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import TrainerState, RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn from pytorch_lightning.utilities.argparse import ( add_argparse_args, @@ -46,8 +46,10 @@ class TrainerProperties(ABC): _default_root_dir: str + _fitting: bool = False _lightning_optimizers = None _progress_bar_callback: ProgressBarBase + _running_stage: Optional[RunningStage] = None _state: TrainerState _weights_save_path: str @@ -412,6 +414,65 @@ def distributed_sampler_kwargs(self) -> Optional[dict]: if isinstance(self.training_type_plugin, ParallelPlugin): return self.training_type_plugin.distributed_sampler_kwargs + @property + def training(self) -> bool: + return self._running_stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TRAINING + elif self.training: + self._running_stage = None + + @property + def testing(self) -> bool: + return self._running_stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TESTING + elif self.testing: + self._running_stage = None + + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + + @property + def tuning(self) -> bool: + return self._running_stage == RunningStage.TUNING + + @tuning.setter + def tuning(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TUNING + elif self.tuning: + self._running_stage = None + + @property + def validating(self) -> bool: + return self._running_stage == RunningStage.VALIDATING + + @validating.setter + def validating(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.VALIDATING + elif self.validating: + self._running_stage = None + + @property + def evaluating(self) -> bool: + return self._running_stage and self._running_stage.is_evaluating() + # Used to represent the concrete type TrainerProperties class methods are called on. 
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7edd820bf0192..436360b09e8f9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -304,7 +304,6 @@ def __init__( """ super().__init__() - self._running_stage = None distributed_backend = distributed_backend or accelerator @@ -445,11 +444,10 @@ def fit( """ # bookkeeping self._state = TrainerState.RUNNING - - # bookkeeping # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. if self._running_stage is None: self._running_stage = RunningStage.TRAINING + self._fitting = self.training # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -531,6 +529,7 @@ def fit( self._state = TrainerState.FINISHED self._running_stage = None + self._fitting = False return self.accelerator.results or 1 @@ -604,9 +603,6 @@ def run_train(self): self.run_sanity_check(self.lightning_module) - # set stage for logging - self._running_stage = RunningStage.TRAINING - self.checkpoint_connector.has_trained = False # enable train mode @@ -667,13 +663,9 @@ def run_train(self): self.train_loop.on_train_end() def run_evaluation(self, max_batches=None, on_epoch=False): - # used to know if we are logging for val, test + reset cached results - self._running_stage = RunningStage.TESTING if self.testing else RunningStage.VALIDATING + # reset cached results self.logger_connector.reset() - # bookkeeping - self.evaluation_loop.testing = self.testing - # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) @@ -895,6 +887,7 @@ def test( self.verbose_evaluate = verbose self._running_stage = RunningStage.TESTING + self._fitting = False # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: @@ -915,6 +908,7 @@ def test( self.teardown('test') self._running_stage = None + return results def __evaluate_using_best_weights( @@ -1008,6 +1002,7 @@ def predict( model = model or self.lightning_module self._running_stage = RunningStage.PREDICTING + self._fitting = False if dataloaders and datamodule: raise MisconfigurationException( @@ -1024,6 +1019,7 @@ def predict( self.model = model results = self.fit(model) + self._running_stage = None return results @@ -1113,62 +1109,3 @@ def call_hook(self, hook_name, *args, **kwargs): if not skip: self._cache_logged_metrics() return output - - @property - def training(self) -> bool: - return self._running_stage == RunningStage.TRAINING - - @training.setter - def training(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TRAINING - elif self.training: - self._running_stage = None - - @property - def testing(self) -> bool: - return self._running_stage == RunningStage.TESTING - - @testing.setter - def testing(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TESTING - elif self.testing: - self._running_stage = None - - @property - def predicting(self) -> bool: - return self._running_stage == RunningStage.PREDICTING - - @predicting.setter - def predicting(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.PREDICTING - elif self.predicting: - self._running_stage = None - - @property - def tuning(self) -> bool: - return self._running_stage == RunningStage.TUNING - - @tuning.setter - def tuning(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TUNING - elif self.tuning: - 
self._running_stage = None - - @property - def validating(self) -> bool: - return self._running_stage == RunningStage.VALIDATING - - @validating.setter - def validating(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.VALIDATING - elif self.validating: - self._running_stage = None - - @property - def evaluating(self) -> bool: - return self._running_stage and self._running_stage.is_evaluating() From 68469ed088f3363c55886d6202a43045d556a69d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 02:41:05 +0100 Subject: [PATCH 04/34] Use trainer.fitting --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/sharded.py | 3 +-- pytorch_lightning/plugins/training_type/sharded_spawn.py | 3 +-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 8 ++++---- pytorch_lightning/trainer/configuration_validator.py | 7 +++---- pytorch_lightning/trainer/deprecated_api.py | 1 - pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 8 ++++---- tests/callbacks/test_progress_bar.py | 2 +- 12 files changed, 20 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3e185cabc84a2..2f11a432b7cd8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -311,7 +311,7 @@ def setup_optimizers(self, trainer): trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if trainer.evaluating: + if not trainer.fitting: return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d3a1adf2b41d2..0d1ca66310e5a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -155,7 +155,7 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check or trainer.evaluating: + if trainer.running_sanity_check or not trainer.fitting: return self._run_early_stopping_check(trainer, pl_module) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bc53a9f19dc1f..acb3b35ce0bd7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -235,7 +235,7 @@ def save_checkpoint(self, trainer, pl_module): if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.evaluating # disable checkpointing during validation and test + or not trainer.fitting # only save during fit or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 77ed11604dee0..e5f75e4ab8fbc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -219,7 +219,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): 
last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? if ( - not self.lightning_module.trainer.evaluating + self.lightning_module.trainer.fitting and best_model_path is not None and len(best_model_path) > 0 ): @@ -239,7 +239,7 @@ def __recover_child_process_weights(self, best_path, last_path): # todo, pass also best score # load last weights - if last_path is not None and not self.lightning_module.trainer.evaluating: + if last_path is not None and self.lightning_module.trainer.fitting: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 66f2e4f56a230..29c960ae40efa 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -35,8 +35,7 @@ def _reinit_optimizers_with_oss(self): trainer.convert_to_lightning_optimizers() def _wrap_optimizers(self): - trainer = self.model.trainer - if trainer.evaluating is True: + if not self.model.trainer.fitting: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index c7b770316c6fa..d29c0744017d3 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -31,8 +31,7 @@ def _reinit_optimizers_with_oss(self): trainer.optimizers = optimizers def _wrap_optimizers(self): - trainer = self.model.trainer - if trainer.evaluating: + if not self.model.trainer.fitting: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index e53bea1a64529..a162747b763c9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -125,7 +125,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? 
         if (
-            not self.lightning_module.trainer.evaluating
+            self.lightning_module.trainer.fitting
             and best_model_path is not None
             and len(best_model_path) > 0
         ):
@@ -218,7 +218,7 @@ def post_dispatch(self) -> None:
 
         # todo, pass also bets score
         # load last weights
-        if last_path and not self.lightning_module.trainer.evaluating:
+        if last_path and self.lightning_module.trainer.fitting:
             ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt)
 
@@ -229,10 +229,10 @@ def post_dispatch(self) -> None:
 
     def __load_weights_on_main_process(self) -> None:
         model = self.lightning_module
+        assert hasattr(model, "trainer")
 
         # load weights if not interrupted
-        # TODO: check for trainer reference
-        if on_colab_kaggle() and not model.trainer.evaluating:
+        if on_colab_kaggle() and model.trainer.fitting:
             self.load_spawn_weights(model)
 
         self._model = model
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 220066bca1cfc..8b7ce13004636 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -24,16 +24,15 @@ def __init__(self, trainer):
 
     def verify_loop_configurations(self, model: LightningModule):
         r"""
-        Checks that the model is configured correctly before training or testing is started.
+        Checks that the model is configured correctly before the run is started.
 
         Args:
             model: The model to check the configuration.
 
         """
-        if not self.trainer.evaluating:
+        if self.trainer.training:
             self.__verify_train_loop_configuration(model)
-        else:
-            # check evaluation loop configuration
+        elif self.trainer.evaluating:
             self.__verify_eval_loop_configuration(model)
 
     def __verify_train_loop_configuration(self, model):
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
index 46cfc545c889d..afa288c51da3a 100644
--- a/pytorch_lightning/trainer/deprecated_api.py
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -22,7 +22,6 @@ class DeprecatedDistDeviceAttributes:
 
     _distrib_type: DistributedType
     _device_type: DeviceType
-    _running_stage: RunningStage
     num_gpus: int
     accelerator_connector: AcceleratorConnector
 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 37ab583ce89b1..889b7aedc88d3 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -46,7 +46,6 @@ class TrainerProperties(ABC):
 
     _default_root_dir: str
-    _fitting: bool = False
     _lightning_optimizers = None
     _progress_bar_callback: ProgressBarBase
     _running_stage: Optional[RunningStage] = None
@@ -56,6 +55,7 @@ class TrainerProperties(ABC):
     accelerator_connector: AcceleratorConnector
     callbacks: List[Callback]
     checkpoint_connector: CheckpointConnector
+    fitting: bool = False  # to differentiate between .fit() validation and .validate() validation
     limit_val_batches: int
     logger: LightningLoggerBase
     logger_connector: LoggerConnector
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 436360b09e8f9..c7833d1d5f1c3 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -447,7 +447,7 @@ def fit(
         # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
if self._running_stage is None: self._running_stage = RunningStage.TRAINING - self._fitting = self.training + self.fitting = self.training # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -529,7 +529,7 @@ def fit( self._state = TrainerState.FINISHED self._running_stage = None - self._fitting = False + self.fitting = False return self.accelerator.results or 1 @@ -887,7 +887,7 @@ def test( self.verbose_evaluate = verbose self._running_stage = RunningStage.TESTING - self._fitting = False + self.fitting = False # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: @@ -1002,7 +1002,7 @@ def predict( model = model or self.lightning_module self._running_stage = RunningStage.PREDICTING - self._fitting = False + self.fitting = False if dataloaders and datamodule: raise MisconfigurationException( diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f16d8afd9cffd..9fe97717ca404 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -426,7 +426,7 @@ def test_progress_bar_print(tqdm_write, tmpdir): @mock.patch('builtins.print') @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): - """ Test that printing in LightningModule goes through built-in print functin when progress bar is disabled. """ + """ Test that printing in LightningModule goes through built-in print function when progress bar is disabled. """ model = PrintModel() bar = ProgressBar() trainer = Trainer( From 89aa994ba5b02660ca1c0ec5e1bb85d922632c93 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 03:49:49 +0100 Subject: [PATCH 05/34] Fix reset dataloaders --- pytorch_lightning/trainer/evaluation_loop.py | 4 +--- pytorch_lightning/trainer/predict_loop.py | 5 ++--- pytorch_lightning/trainer/trainer.py | 9 ++++++--- pytorch_lightning/trainer/training_loop.py | 4 +--- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 69155dd224d08..d07cfbff9230a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -56,9 +56,7 @@ def get_evaluation_dataloaders(self, max_batches): new_max_batches = self.trainer.num_test_batches else: # val - in_sanity_check = self.trainer.running_sanity_check - should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch - if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check: + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) dataloaders = self.trainer.val_dataloaders diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 6b801cc7f5dea..40507a1bc03f4 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -27,9 +27,8 @@ def on_trainer_init(self): self.trainer.num_predict_batches = [] def get_predict_dataloaders(self, max_batches): - # select dataloaders - model = self.trainer.lightning_module - self.trainer.reset_predict_dataloader(model) + self.trainer.reset_predict_dataloader(self.trainer.lightning_module) + dataloaders = self.trainer.predict_dataloaders if max_batches is None: max_batches = self.trainer.num_predict_batches diff --git 
a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a9b5d2efc78d7..0991680b018ae 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -640,6 +640,8 @@ def run_train(self): self.train_loop.on_train_end() def run_evaluation(self, max_batches=None, on_epoch=False): + assert self._running_stage.is_evaluating() + # reset cached results self.logger_connector.reset() @@ -744,8 +746,8 @@ def run_test(self): if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - # only load test dataloader for testing - # self.reset_test_dataloader(ref_model) + assert self.testing + with self.profiler.profile("run_test_evaluation"): eval_loop_results, _ = self.run_evaluation() @@ -807,7 +809,6 @@ def run_sanity_check(self, ref_model): # run tiny validation (if validation defined) # to make sure program won't crash during val if should_sanity_check: - self.reset_val_dataloader(ref_model) self.num_sanity_val_batches = [ min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches ] @@ -817,7 +818,9 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_start() # run eval step + self._running_stage = RunningStage.VALIDATING _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches) + self._running_stage = RunningStage.TRAINING # allow no returns from eval if eval_results is not None and len(eval_results) > 0: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 57ad3f6b06d36..702f8c15e0c5d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -477,7 +477,6 @@ def run_training_epoch(self): train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - should_check_val = False val_loop_called = False for batch_idx, (batch, is_last_batch) in train_dataloader: @@ -572,9 +571,8 @@ def run_training_epoch(self): self.check_early_stopping_callback(True) if should_check_val: + self.trainer._running_stage = RunningStage.VALIDATING self.trainer.run_evaluation(on_epoch=True) - - # reset stage to train self.trainer._running_stage = RunningStage.TRAINING # increment the global step once From 3c6e99c4528bdcfe9acb350ab2bf790272667705 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 03:50:46 +0100 Subject: [PATCH 06/34] Unused code --- pytorch_lightning/trainer/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0991680b018ae..c7dd04d7825e9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -469,10 +469,6 @@ def fit( # `run_predict` is the simplest to understand, use `Go to Definition` to read it :) # Search for `start_training` or `start_testing` or `start_predicting` in # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions. 
- self.accelerator.train_loop = self.run_train - self.accelerator.validation_loop = self.run_evaluation - self.accelerator.test_loop = self.run_evaluation - self.accelerator.predict_loop = self.run_predict # ---------------------------- # TRAIN From 9ba12d7952c22658ee66e50eb4d0c5a4081e680f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 03:51:45 +0100 Subject: [PATCH 07/34] RunningStage.SANITY_CHECKING --- pytorch_lightning/callbacks/early_stopping.py | 6 ++--- .../callbacks/model_checkpoint.py | 2 +- pytorch_lightning/callbacks/progress.py | 2 +- .../logger_connector/logger_connector.py | 4 ++-- pytorch_lightning/trainer/deprecated_api.py | 8 +++++++ pytorch_lightning/trainer/evaluation_loop.py | 3 +-- pytorch_lightning/trainer/properties.py | 23 ++++++++++++++----- pytorch_lightning/trainer/states.py | 1 + pytorch_lightning/trainer/trainer.py | 7 ++++-- pytorch_lightning/utilities/debugging.py | 2 +- 10 files changed, 40 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 3e04a9f762073..93de7c824bcb3 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -132,12 +132,12 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check or not trainer.fitting: + if trainer.sanity_checking or not trainer.fitting: return - self._run_early_stopping_check(trainer, pl_module) + self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer, pl_module): + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met and if so tells the trainer to stop the training. 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e4a216fc76fd2..a6182fd7c4e99 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -219,7 +219,7 @@ def save_checkpoint(self, trainer, pl_module): or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch - or trainer.running_sanity_check # don't save anything during sanity check + or trainer.sanity_checking # don't save anything during sanity check or self._last_global_step_saved == global_step # already saved at the last step ): return diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 2f133eaccf512..a4d9a87894b0b 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -412,7 +412,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - if not trainer.running_sanity_check: + if not trainer.sanity_checking: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, self.total_val_batches) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8808edbaf0fff..8604f8535f222 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -289,7 +289,7 @@ def prepare_eval_loop_results(self): self.add_to_eval_loop_results(dl_idx, has_been_initialized) def get_evaluate_epoch_results(self): - if not self.trainer.running_sanity_check: + if not self.trainer.sanity_checking: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() if len(metrics_to_log) > 0: @@ -372,7 +372,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric self.eval_loop_results.append(dataloader_result_metrics) def __process_eval_epoch_end_results_and_log_legacy(self, eval_results): - if self.trainer.running_sanity_check: + if self.trainer.sanity_checking: return if eval_results is not None and len(eval_results) > 0: diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index afa288c51da3a..ff3bc2a876629 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -137,6 +137,7 @@ class DeprecatedTrainerAttributes: accelerator: Accelerator lightning_module = LightningModule + sanity_checking: bool @property def accelerator_backend(self) -> Accelerator: @@ -152,3 +153,10 @@ def get_model(self) -> LightningModule: " and will be removed in v1.4.", DeprecationWarning ) return self.lightning_module + + @property + def running_sanity_check(self) -> bool: + rank_zero_warn( + "TODO", DeprecationWarning + ) + return self.sanity_checking diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d07cfbff9230a..03df03f8c77af 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -37,7 +37,6 @@ def on_trainer_init(self): self.trainer.num_val_batches = [] self.trainer.test_dataloaders = None 
self.trainer.val_dataloaders = None - self.trainer.running_sanity_check = False # .validate() and .test() set this when they load a checkpoint self.trainer.evaluated_ckpt_path = None @@ -318,7 +317,7 @@ def on_evaluation_epoch_end(self, *args, **kwargs): self.trainer.call_hook('on_epoch_end') def log_evaluation_step_metrics(self, output, batch_idx): - if self.trainer.running_sanity_check: + if self.trainer.sanity_checking: return step_log_metrics = {} diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 889b7aedc88d3..8767c1c027bc5 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -31,7 +31,7 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.states import TrainerState, RunningStage +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn from pytorch_lightning.utilities.argparse import ( add_argparse_args, @@ -416,7 +416,7 @@ def distributed_sampler_kwargs(self) -> Optional[dict]: @property def training(self) -> bool: - return self._running_stage == RunningStage.TRAINING + return self._running_stage is RunningStage.TRAINING @training.setter def training(self, val: bool) -> None: @@ -427,7 +427,7 @@ def training(self, val: bool) -> None: @property def testing(self) -> bool: - return self._running_stage == RunningStage.TESTING + return self._running_stage is RunningStage.TESTING @testing.setter def testing(self, val: bool) -> None: @@ -438,7 +438,7 @@ def testing(self, val: bool) -> None: @property def predicting(self) -> bool: - return self._running_stage == RunningStage.PREDICTING + return self._running_stage is RunningStage.PREDICTING @predicting.setter def predicting(self, val: bool) -> None: @@ -449,7 +449,7 @@ def predicting(self, val: bool) -> None: @property def tuning(self) -> bool: - return self._running_stage == RunningStage.TUNING + return self._running_stage is RunningStage.TUNING @tuning.setter def tuning(self, val: bool) -> None: @@ -460,7 +460,7 @@ def tuning(self, val: bool) -> None: @property def validating(self) -> bool: - return self._running_stage == RunningStage.VALIDATING + return self._running_stage is RunningStage.VALIDATING @validating.setter def validating(self, val: bool) -> None: @@ -473,6 +473,17 @@ def validating(self, val: bool) -> None: def evaluating(self) -> bool: return self._running_stage and self._running_stage.is_evaluating() + @property + def sanity_checking(self) -> bool: + return self._running_stage is RunningStage.SANITY_CHECKING + + @sanity_checking.setter + def sanity_checking(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.SANITY_CHECKING + elif self.sanity_checking: + self._running_stage = None + # Used to represent the concrete type TrainerProperties class methods are called on. 
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 8194fd990cad5..1fe52c3e8b5cd 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -44,6 +44,7 @@ class RunningStage(LightningEnum): True """ TRAINING = 'train' + SANITY_CHECKING = "sanity_check" VALIDATING = 'validation' TESTING = 'test' PREDICTING = 'predict' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c7dd04d7825e9..35d76ac58ba80 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -805,12 +805,14 @@ def run_sanity_check(self, ref_model): # run tiny validation (if validation defined) # to make sure program won't crash during val if should_sanity_check: + stage = self._running_stage + self.sanity_checking = True + self.num_sanity_val_batches = [ min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches ] # hook and callback - self.running_sanity_check = True self.on_sanity_check_start() # run eval step @@ -828,7 +830,8 @@ def run_sanity_check(self, ref_model): self.logger_connector.callback_metrics = callback_metrics self.on_sanity_check_end() - self.running_sanity_check = False + + self._running_stage = stage def test( self, diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 65cf4472d156c..56833fd03735a 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -139,7 +139,7 @@ def track_lr_schedulers_update( @enabled_only def track_eval_loss_history(self, batch_idx, dataloader_idx, output): loss_dict = { - 'sanity_check': self.trainer.running_sanity_check, + 'sanity_check': self.trainer.sanity_checking, 'dataloader_idx': dataloader_idx, 'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, From f3d16a4fae39d158561067b226a34aa358cb2304 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 04:01:48 +0100 Subject: [PATCH 08/34] Use setters --- pytorch_lightning/trainer/trainer.py | 14 ++++++-------- pytorch_lightning/trainer/training_loop.py | 6 +++--- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 35d76ac58ba80..cd8effc9646ed 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -423,7 +423,7 @@ def fit( self._state = TrainerState.RUNNING # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. 
if self._running_stage is None: - self._running_stage = RunningStage.TRAINING + self.training = True self.fitting = self.training # set local properties on the model @@ -636,7 +636,7 @@ def run_train(self): self.train_loop.on_train_end() def run_evaluation(self, max_batches=None, on_epoch=False): - assert self._running_stage.is_evaluating() + assert self.evaluating or self.sanity_checking # reset cached results self.logger_connector.reset() @@ -816,9 +816,7 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_start() # run eval step - self._running_stage = RunningStage.VALIDATING _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches) - self._running_stage = RunningStage.TRAINING # allow no returns from eval if eval_results is not None and len(eval_results) > 0: @@ -865,7 +863,7 @@ def test( # -------------------- self.verbose_evaluate = verbose - self._running_stage = RunningStage.TESTING + self.testing = True self.fitting = False # If you supply a datamodule you can't supply test_dataloaders @@ -886,7 +884,7 @@ def test( ) self.teardown('test') - self._running_stage = None + self.testing = False return results @@ -980,7 +978,7 @@ def predict( model = model or self.lightning_module - self._running_stage = RunningStage.PREDICTING + self.predicting = True self.fitting = False if dataloaders and datamodule: @@ -999,7 +997,7 @@ def predict( self.model = model results = self.fit(model) - self._running_stage = None + self.predicting = False return results diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 702f8c15e0c5d..daa0b76b0c1f7 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -516,7 +516,7 @@ def run_training_epoch(self): val_loop_called = True # reset stage to train - self.trainer._running_stage = RunningStage.TRAINING + self.trainer.training = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
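# --- Editorial aside (not part of the patch series) ----------------------
# Minimal sketch of what the stage setters used above are expected to do,
# assuming this patch series is applied (the boolean properties defined in
# pytorch_lightning/trainer/properties.py earlier in the series). The
# Trainer instance and the direct attribute accesses are illustrative only.
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage

trainer = Trainer()
trainer.validating = True    # equivalent to _running_stage = RunningStage.VALIDATING
assert trainer._running_stage is RunningStage.VALIDATING
assert trainer.evaluating    # VALIDATING (and TESTING) count as evaluating
trainer.training = True      # switch back, as run_training_epoch() does after run_evaluation()
assert trainer.training and not trainer.validating
# --------------------------------------------------------------------------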
@@ -571,9 +571,9 @@ def run_training_epoch(self): self.check_early_stopping_callback(True) if should_check_val: - self.trainer._running_stage = RunningStage.VALIDATING + self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) - self.trainer._running_stage = RunningStage.TRAINING + self.trainer.training = True # increment the global step once # progress global step according to grads progress From 0697c3e92a435898550770dc73f8f0183e554622 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 05:10:12 +0100 Subject: [PATCH 09/34] Fix bugs --- .../callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/overrides/base.py | 8 ++++---- .../plugins/training_type/rpc_sequential.py | 3 ++- pytorch_lightning/trainer/evaluation_loop.py | 18 ++++++++++-------- pytorch_lightning/trainer/trainer.py | 18 ++++++++++-------- pytorch_lightning/trainer/training_loop.py | 5 ++--- tests/callbacks/test_callbacks.py | 2 +- 7 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a6182fd7c4e99..3685dc2914f18 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -215,11 +215,11 @@ def save_checkpoint(self, trainer, pl_module): if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or not trainer.fitting # only save during fit + or not trainer.fitting # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch - or trainer.sanity_checking # don't save anything during sanity check or self._last_global_step_saved == global_step # already saved at the last step ): return diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index c0b691bb07cb8..bf7be4ba774b2 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -45,7 +45,7 @@ def __init__(self, pl_module: LightningModule): def forward(self, *inputs, **kwargs): running_stage = self.module.running_stage - if running_stage == RunningStage.TRAINING: + if running_stage is RunningStage.TRAINING: output = self.module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as @@ -56,15 +56,15 @@ def forward(self, *inputs, **kwargs): self.module.trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") - elif running_stage == RunningStage.TESTING: + elif running_stage is RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif running_stage == RunningStage.EVALUATING: + elif running_stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING): output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif running_stage == RunningStage.PREDICTING: + elif running_stage is RunningStage.PREDICTING: output = self.module.predict(*inputs, **kwargs) warn_if_output_is_none(output, "predict") diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 3878aa9db3ea4..df61006a24406 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -208,7 +208,8 @@ def 
_skip_init_connections(self): Returns: Whether to skip initialization """ - return torch_distrib.is_initialized() and self.lightning_module.running_stage == RunningStage.TESTING + # TODO: should this use trainer.fitting? same for other occurrences in this file + return torch_distrib.is_initialized() and self.lightning_module.running_stage is RunningStage.TESTING def init_model_parallel_groups(self): num_model_parallel = 1 # TODO currently no support for vertical model parallel diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 03df03f8c77af..b45d58c836648 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -44,7 +44,7 @@ def on_trainer_init(self): # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True - def get_evaluation_dataloaders(self, max_batches): + def get_evaluation_dataloaders(self): model = self.trainer.lightning_module # select dataloaders @@ -52,18 +52,20 @@ def get_evaluation_dataloaders(self, max_batches): self.trainer.reset_test_dataloader(model) dataloaders = self.trainer.test_dataloaders - new_max_batches = self.trainer.num_test_batches + max_batches = self.trainer.num_test_batches else: # val if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) - + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) + for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches dataloaders = self.trainer.val_dataloaders - new_max_batches = self.trainer.num_val_batches - - if max_batches is None: - max_batches = new_max_batches - return dataloaders, max_batches def should_skip_evaluation(self, max_batches): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cd8effc9646ed..b9e15ef0f3838 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,6 +15,7 @@ import warnings from itertools import count from pathlib import Path +from traceback import print_exc from typing import Dict, Iterable, List, Optional, Union import torch @@ -52,7 +53,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -631,18 +632,23 @@ def run_train(self): self.interrupted = True self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() + except (RuntimeError, AssertionError): + # if an exception is raised, the finally block is executed and can hide the actual exception + # that was initially raised if `on_train_end` also raises an exception. 
we want to avoid that + # for assertions and other runtime errors so we aren't misled while debugging + print_exc() finally: # hook self.train_loop.on_train_end() - def run_evaluation(self, max_batches=None, on_epoch=False): + def run_evaluation(self, on_epoch=False): assert self.evaluating or self.sanity_checking # reset cached results self.logger_connector.reset() # prepare dataloaders - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() # check if we want to skip this evaluation if self.evaluation_loop.should_skip_evaluation(max_batches): @@ -808,15 +814,11 @@ def run_sanity_check(self, ref_model): stage = self._running_stage self.sanity_checking = True - self.num_sanity_val_batches = [ - min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches - ] - # hook and callback self.on_sanity_check_start() # run eval step - _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches) + _, eval_results = self.run_evaluation() # allow no returns from eval if eval_results is not None and len(eval_results) > 0: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index daa0b76b0c1f7..7bd754af89709 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -512,11 +512,10 @@ def run_training_epoch(self): # ----------------------------------------- should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: + self.trainer.validating = True self.trainer.run_evaluation() - val_loop_called = True - - # reset stage to train self.trainer.training = True + val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
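# --- Editorial aside (not part of the patch series) ----------------------
# Worked example of the sanity-check batch limiting that
# get_evaluation_dataloaders() now performs in the evaluation_loop.py hunk
# of this patch: each val dataloader is capped at num_sanity_val_steps
# batches. The concrete values below are made up for illustration.
num_sanity_val_steps = 2
num_val_batches = [5, 1]  # e.g. two val dataloaders with 5 and 1 batches
num_sanity_val_batches = [min(num_sanity_val_steps, n) for n in num_val_batches]
assert num_sanity_val_batches == [2, 1]
# --------------------------------------------------------------------------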
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 379bc79263a6e..56dadbc761d91 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,7 +19,7 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system(torch_save, tmpdir): +def test_trainer_callback_system(_, tmpdir): """Test the callback system.""" model = BoringModel() From 39686ae6de84ac8ac05ccc6ffaed45f2efa3bb2e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 06:21:14 +0100 Subject: [PATCH 10/34] Fix bugs --- .../trainer/connectors/optimizer_connector.py | 2 +- pytorch_lightning/trainer/data_loading.py | 6 +- pytorch_lightning/trainer/trainer.py | 9 ++- pytorch_lightning/trainer/training_loop.py | 5 +- .../test_eval_loop_dict_return.py | 5 +- tests/trainer/test_dataloaders.py | 72 ++++++------------- tests/trainer/test_trainer.py | 5 +- 7 files changed, 44 insertions(+), 60 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 1a1a992758dc8..a50603bb58dbf 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -51,7 +51,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): ) if monitor_val is None: if lr_scheduler.get('strict', True): - avail_metrics = self.trainer.logger_connector.callback_metrics.keys() + avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys()) raise MisconfigurationException( f'ReduceLROnPlateau conditioned on metric {monitor_key}' f' which is not available. Available metrics are: {avail_metrics}.' diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 06a3da750032c..95bd8b3f8cc44 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -372,8 +372,7 @@ def reset_test_dataloader(self, model) -> None: has_loader = is_overridden('test_dataloader', model) has_step = is_overridden('test_step', model) if has_loader and has_step: - self.num_test_batches, self.test_dataloaders =\ - self._reset_eval_dataloader(model, 'test') + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test') def reset_predict_dataloader(self, model) -> None: """Resets the predict dataloader and determines the number of batches. @@ -383,8 +382,7 @@ def reset_predict_dataloader(self, model) -> None: """ has_loader = is_overridden('predict_dataloader', model) if has_loader: - self.num_predict_batches, self.predict_dataloaders =\ - self._reset_eval_dataloader(model, 'predict') + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: """Handles downloading data in the GPU or TPU case. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b9e15ef0f3838..1994d4947c811 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -642,7 +642,12 @@ def run_train(self): self.train_loop.on_train_end() def run_evaluation(self, on_epoch=False): - assert self.evaluating or self.sanity_checking + if not (self.evaluating or self.sanity_checking): + rank_zero_warn( + f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}." 
+ " This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning + ) + self.validating = True # reset cached results self.logger_connector.reset() @@ -925,7 +930,7 @@ def __evaluate_using_best_weights( self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) # run test - self.tested_ckpt_path = ckpt_path + self.evaluated_ckpt_path = ckpt_path results = self.fit(model) # teardown diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7bd754af89709..8f777321ae7cd 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -123,7 +123,6 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): def on_train_end(self): if self._teardown_already_run: return - self._teardown_already_run = True # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates @@ -148,6 +147,10 @@ def on_train_end(self): # give accelerators a chance to finish self.trainer.accelerator.on_train_end() + # reset bookkeeping + self.trainer._running_stage = None + self.trainer.fitting = False + def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: diff --git a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py index 87cab653de6aa..2aac7354c38f6 100644 --- a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py @@ -14,6 +14,8 @@ """ Tests to ensure that the training loop works with a dict """ +import pytest + from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule from tests.helpers.deterministic_model import DeterministicModel @@ -44,7 +46,8 @@ def backward(self, loss, optimizer, optimizer_idx): # out are the results of the full loop # eval_results are output of _evaluate - out, eval_results = trainer.run_evaluation() + with pytest.warns(RuntimeWarning, match="the running stage is set to None"): + out, eval_results = trainer.run_evaluation() assert len(out) == 1 assert len(eval_results) == 0 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index fe07e41d20b4c..e163559f770c3 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -600,24 +600,20 @@ def test_error_on_zero_len_dataloader(tmpdir): @pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) @patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) -def test_warning_with_few_workers(mock, tmpdir, ckpt_path): +def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): """ Test that error is raised if dataloader with only a few workers is used """ - model = EvalModelTemplate() + model = BoringModel() - # logger file to get meta - train_dl = model.dataloader(train=True) + train_dl = model.train_dataloader() train_dl.num_workers = 0 - val_dl = model.dataloader(train=False) + val_dl = model.val_dataloader() val_dl.num_workers = 0 - train_dl = model.dataloader(train=False) - 
train_dl.num_workers = 0 - - fit_options = dict(train_dataloader=train_dl, val_dataloaders=val_dl) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -625,30 +621,22 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): limit_train_batches=0.2, ) - # fit model with pytest.warns( - UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' + UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): - trainer.fit(model, **fit_options) - - with pytest.warns( - UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path) - with pytest.warns( - UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' - ): - trainer.test(**test_options) + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) @pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) @patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) -def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): +def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): """ Test that error is raised if dataloader with only a few workers is used """ model = EvalModelTemplate() @@ -658,10 +646,6 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): model.test_step = model.test_step__multiple_dataloaders model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # logger file to get meta - train_dl = model.dataloader(train=True) - train_dl.num_workers = 0 - val_dl = model.dataloader(train=False) val_dl.num_workers = 0 @@ -672,7 +656,6 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] - fit_options = dict(train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -680,24 +663,15 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): limit_train_batches=0.2, ) - # fit model - with pytest.warns( - UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - with pytest.warns( - UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) with pytest.warns( - UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' 
+ UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): - trainer.test(**test_options) + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) def test_warning_with_iterable_dataset_and_len(tmpdir): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9c02634db024c..d48e86e90c535 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -446,10 +446,11 @@ def mock_save_function(filepath, *args): monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, - verbose=1 + verbose=True ) checkpoint_callback.save_function = mock_save_function trainer = Trainer() + trainer.fitting = True # emulate callback's calls during the training for i, loss in enumerate(losses): @@ -690,7 +691,7 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): +def test_checkpoint_path(tmpdir, ckpt_path, save_top_k): hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) From 0ed386b42361afde513b743a10e474562fec4312 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 17:50:10 +0100 Subject: [PATCH 11/34] TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING} --- pytorch_lightning/accelerators/accelerator.py | 3 +- pytorch_lightning/callbacks/early_stopping.py | 3 +- .../callbacks/model_checkpoint.py | 3 +- .../plugins/training_type/ddp_spawn.py | 5 +-- .../plugins/training_type/sharded.py | 3 +- .../plugins/training_type/sharded_spawn.py | 3 +- .../plugins/training_type/tpu_spawn.py | 7 ++-- .../logger_connector/logger_connector.py | 9 +++-- pytorch_lightning/trainer/properties.py | 11 ++++-- pytorch_lightning/trainer/states.py | 32 ++++++++++++----- pytorch_lightning/trainer/trainer.py | 36 ++++++++++--------- pytorch_lightning/trainer/training_loop.py | 1 - 12 files changed, 77 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2f11a432b7cd8..67c9f89caa86d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -20,6 +20,7 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -311,7 +312,7 @@ def setup_optimizers(self, trainer): trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if not trainer.fitting: + if trainer.state is not TrainerState.FITTING: return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module diff --git a/pytorch_lightning/callbacks/early_stopping.py 
b/pytorch_lightning/callbacks/early_stopping.py index 93de7c824bcb3..19b565a5fc1c1 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -23,6 +23,7 @@ import torch from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -132,7 +133,7 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.sanity_checking or not trainer.fitting: + if trainer.state is not TrainerState.FITTING or trainer.sanity_checking: return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3685dc2914f18..557c1f5cf2217 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -213,9 +213,10 @@ def save_checkpoint(self, trainer, pl_module): epoch = trainer.current_epoch global_step = trainer.global_step + from pytorch_lightning.trainer.states import TrainerState if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or not trainer.fitting # don't save anything during non-fit + or trainer.state is not TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index e5f75e4ab8fbc..d07006ef4fb6d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -219,7 +220,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? 
if ( - self.lightning_module.trainer.fitting + self.lightning_module.trainer.state is TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): @@ -239,7 +240,7 @@ def __recover_child_process_weights(self, best_path, last_path): # todo, pass also best score # load last weights - if last_path is not None and self.lightning_module.trainer.fitting: + if last_path is not None and self.lightning_module.trainer.state is TrainerState.FITTING: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 29c960ae40efa..9eb3d168263c1 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -3,6 +3,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -35,7 +36,7 @@ def _reinit_optimizers_with_oss(self): trainer.convert_to_lightning_optimizers() def _wrap_optimizers(self): - if not self.model.trainer.fitting: + if self.model.trainer.state is not TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index d29c0744017d3..4df0140a26b9f 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -2,6 +2,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -31,7 +32,7 @@ def _reinit_optimizers_with_oss(self): trainer.optimizers = optimizers def _wrap_optimizers(self): - if not self.model.trainer.fitting: + if self.model.trainer.state is not TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a162747b763c9..516b96b082272 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -9,6 +9,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -125,7 +126,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? 
if ( - self.lightning_module.trainer.fitting + self.lightning_module.trainer.state is TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): @@ -218,7 +219,7 @@ def post_dispatch(self) -> None: # todo, pass also bets score # load last weights - if last_path and self.lightning_module.trainer.fitting: + if last_path and model.trainer.state is not TrainerState.FITTING: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -232,7 +233,7 @@ def __load_weights_on_main_process(self) -> None: assert hasattr(model, "trainer") # load weights if not interrupted - if on_colab_kaggle() and model.trainer.fiting: + if on_colab_kaggle() and model.trainer.state is TrainerState.FITTING: self.load_spawn_weights(model) self._model = model diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8604f8535f222..4e70d683f5001 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -24,7 +24,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType, flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -298,7 +298,12 @@ def get_evaluate_epoch_results(self): self.prepare_eval_loop_results() # log results of evaluation - if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: + if ( + self.trainer.state is not TrainerState.FITTING + and self.trainer.evaluating + and self.trainer.is_global_zero + and self.trainer.verbose_evaluate + ): print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS') diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 8767c1c027bc5..d0a8e2328a873 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -55,7 +55,6 @@ class TrainerProperties(ABC): accelerator_connector: AcceleratorConnector callbacks: List[Callback] checkpoint_connector: CheckpointConnector - fitting: bool = False # to differentiate between .fit() validation and .validate() validation limit_val_batches: int logger: LightningLoggerBase logger_connector: LoggerConnector @@ -170,6 +169,14 @@ def progress_bar_metrics(self, x: dict) -> None: def state(self) -> TrainerState: return self._state + @state.setter + def state(self, state: TrainerState) -> None: + self._state = state + + @property + def interrupted(self) -> bool: + return self._state is TrainerState.INTERRUPTED + @property def is_global_zero(self) -> bool: return self.global_rank == 0 @@ -471,7 +478,7 @@ def validating(self, val: bool) -> None: @property def evaluating(self) -> bool: - return self._running_stage and self._running_stage.is_evaluating() + return self._running_stage and 
self._running_stage.evaluating() @property def sanity_checking(self) -> bool: diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 1fe52c3e8b5cd..e00b78a8191b3 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -20,40 +20,56 @@ class TrainerState(LightningEnum): - """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer` - to indicate what is currently or was executed. + """ State for the :class:`~pytorch_lightning.trainer.trainer.Trainer` + to indicate what is currently or was executed. It follows the user-called + functions such as `trainer.fit()` and `trainer.test(). >>> # you can compare the type with a string - >>> TrainerState.RUNNING == 'RUNNING' + >>> TrainerState.FITTING == 'FITTING' True >>> # which is case insensitive >>> TrainerState.FINISHED == 'finished' True """ - INITIALIZING = 'INITIALIZING' - RUNNING = 'RUNNING' + INITIALIZING = 'INITIALIZING' # trainer creation + FITTING = 'FITTING' # trainer.fit() + VALIDATING = 'VALIDATING' # trainer.validate() + TESTING = 'TESTING' # trainer.test() + PREDICTING = 'PREDICTING' # trainer.predict() + TUNING = 'TUNING' # trainer.tune() FINISHED = 'FINISHED' INTERRUPTED = 'INTERRUPTED' + def stopped(self) -> bool: + return self in (self.FINISHED, self.INTERRUPTED) + def running(self) -> bool: + return self in (self.FITTING, self.VALIDATING, self.TESTING, self.PREDICTING, self.TUNING) + +# class RunningStage(LightningEnum): - """Type of train phase. + """Current running stage. + + This stage complements :class:`TrainerState` for example to indicate that + `RunningStage.VALIDATING` will be set both during `TrainerState.FITTING` + and `TrainerState.VALIDATING`. It follows the internal code logic. >>> # you can match the Enum with string >>> RunningStage.TRAINING == 'train' True """ TRAINING = 'train' - SANITY_CHECKING = "sanity_check" + SANITY_CHECKING = 'sanity_check' VALIDATING = 'validation' TESTING = 'test' PREDICTING = 'predict' TUNING = 'tune' - def is_evaluating(self) -> bool: + def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING) +# TODO: this is unused, should remove it def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods which changes state to `entering` before the function execution and `exiting` diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1994d4947c811..29b4b7cc6c571 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -420,12 +420,11 @@ def fit( If the model has a predefined val_dataloaders method this will be skipped """ - # bookkeeping - self._state = TrainerState.RUNNING - # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. + # we reuse fit for other functions. When already set, it shouldn't be modified. 
+ if not self.state.running(): + self.state = TrainerState.FITTING if self._running_stage is None: self.training = True - self.fitting = self.training # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -497,14 +496,12 @@ def fit( if self.is_function_implemented('teardown'): model.teardown('fit') - # return 1 when finished - # used for testing or when we need to know that training succeeded - if self._state != TrainerState.INTERRUPTED: - self._state = TrainerState.FINISHED - + if self.state != TrainerState.INTERRUPTED: + self.state = TrainerState.FINISHED self._running_stage = None - self.fitting = False + # return 1 when finished + # used for testing or when we need to know that training succeeded return self.accelerator.results or 1 def pre_dispatch(self): @@ -626,11 +623,9 @@ def run_train(self): except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') - - # user could press ctrl+c many times... only shutdown once + # user could press Ctrl+c many times... only shutdown once if not self.interrupted: - self.interrupted = True - self._state = TrainerState.INTERRUPTED + self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() except (RuntimeError, AssertionError): # if an exception is raised, the finally block is executed and can hide the actual exception @@ -870,8 +865,8 @@ def test( # -------------------- self.verbose_evaluate = verbose + self.state = TrainerState.TESTING self.testing = True - self.fitting = False # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: @@ -891,6 +886,8 @@ def test( ) self.teardown('test') + + assert self.state.stopped() self.testing = False return results @@ -985,8 +982,8 @@ def predict( model = model or self.lightning_module + self.state = TrainerState.PREDICTING self.predicting = True - self.fitting = False if dataloaders and datamodule: raise MisconfigurationException( @@ -1004,6 +1001,7 @@ def predict( self.model = model results = self.fit(model) + assert self.state.stopped() self.predicting = False return results @@ -1030,8 +1028,14 @@ def tune( If the model has a predefined val_dataloaders method this will be skipped """ + self.state = TrainerState.TUNING + self.tuning = True + self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) + assert self.state.stopped() + self.tuning = False + def call_setup_hook(self, model): # call setup after the ddp process has connected stage_name = 'test' if self.evaluating else 'fit' diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8f777321ae7cd..b8114972ec5b3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -62,7 +62,6 @@ def on_trainer_init( ): self.trainer.global_step = 0 self.trainer.current_epoch = 0 - self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING From 18e851c28a956faa774903dd0fa23a06fcfa9882 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 18:25:16 +0100 Subject: [PATCH 12/34] Fix bugs --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/progress.py | 5 +- pytorch_lightning/trainer/states.py | 34 -------- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/tuning.py | 3 + tests/callbacks/test_progress_bar.py | 27 ++++--- 
tests/trainer/flags/test_fast_dev_run.py | 4 +- tests/trainer/test_states.py | 79 +------------------ tests/trainer/test_trainer.py | 2 +- 10 files changed, 30 insertions(+), 130 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 67c9f89caa86d..0cd0c936f2b86 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -312,7 +312,7 @@ def setup_optimizers(self, trainer): trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if trainer.state is not TrainerState.FITTING: + if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING): return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 19b565a5fc1c1..81d1e5c31a10c 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -23,7 +23,6 @@ import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -133,6 +132,7 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState if trainer.state is not TrainerState.FITTING or trainer.sanity_checking: return diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a4d9a87894b0b..c382e67b21a64 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -380,7 +380,6 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -412,7 +411,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - if not trainer.sanity_checking: + if trainer.sanity_checking: + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) + else: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, self.total_val_batches) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index e00b78a8191b3..86c7a544ee173 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -67,37 +67,3 @@ class RunningStage(LightningEnum): def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING) - - -# TODO: this is unused, should remove it -def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable: - """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods - which changes state to `entering` before the function execution and `exiting` - after the function is executed. 
If `None` is passed to `entering`, the state is not changed. - If `None` is passed to `exiting`, the state is restored to the state before function execution. - If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`. - """ - - def wrapper(fn) -> Callable: - - @wraps(fn) - def wrapped_fn(self, *args, **kwargs): - if not isinstance(self, pytorch_lightning.Trainer): - return fn(self, *args, **kwargs) - - state_before = self._state - if entering is not None: - self._state = entering - result = fn(self, *args, **kwargs) - - # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted - # we retain INTERRUPTED state - if self._state == TrainerState.INTERRUPTED: - return result - - self._state = exiting if exiting is not None else state_before - return result - - return wrapped_fn - - return wrapper diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 29b4b7cc6c571..c839ecf6afb84 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -479,7 +479,7 @@ def fit( # plugin will setup fitting (e.g. ddp will launch child processes) self.pre_dispatch() - # dispath `start_training` or `start_testing` or `start_predicting` + # dispatch `start_training` or `start_testing` or `start_predicting` self.dispatch() # plugin will finalized fitting (e.g. ddp_spawn will load trained model) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 06475547b03f2..78810141b1369 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -18,6 +18,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import lr_find @@ -55,6 +56,8 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): if self.trainer.auto_lr_find: self.lr_find(model, update_attr=True) + self.trainer.state = TrainerState.FINISHED + def scale_batch_size( self, model, diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 9fe97717ca404..e4171a8520353 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -218,31 +218,32 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert progress_bar.test_batches_seen == progress_bar.total_test_batches -@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ - pytest.param(0, 0), - pytest.param(5, 7), -]) -def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): +@pytest.mark.parametrize('limit_val_batches', (0, 5)) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches): """ Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. 
""" class CurrentProgressBar(ProgressBar): + val_pbar_total = 0 + sanity_pbar_total = 0 - def __init__(self): - super().__init__() - self.val_progress_bar_total = 0 + def on_sanity_check_end(self, *args): + self.sanity_pbar_total = self.val_progress_bar.total + super().on_sanity_check_end(*args) - def on_validation_epoch_end(self, trainer, pl_module): - self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total + def on_validation_epoch_end(self, *args): + self.val_pbar_total = self.val_progress_bar.total + super().on_validation_epoch_end(*args) model = BoringModel() progress_bar = CurrentProgressBar() + num_sanity_val_steps = 2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - num_sanity_val_steps=2, + num_sanity_val_steps=num_sanity_val_steps, limit_train_batches=1, limit_val_batches=limit_val_batches, callbacks=[progress_bar], @@ -250,7 +251,9 @@ def on_validation_epoch_end(self, trainer, pl_module): checkpoint_callback=False, ) trainer.fit(model) - assert trainer.progress_bar_callback.val_progress_bar_total == expected + + assert progress_bar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) + assert progress_bar.val_pbar_total == limit_val_batches def test_progress_bar_default_value(tmpdir): diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 221951e788284..3c679acffc191 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -20,8 +20,8 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False, - auto_lr_find=True if tuner_alg == 'learning rate finder' else False, + auto_scale_batch_size=tuner_alg == 'batch size scaler', + auto_lr_find=tuner_alg == 'learning rate finder', fast_dev_run=True ) expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 4e067fe22feb6..34f766b1ab21e 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -14,7 +14,7 @@ import pytest from pytorch_lightning import Callback, Trainer -from pytorch_lightning.trainer.states import trainer_state, TrainerState +from pytorch_lightning.trainer.states import TrainerState from tests.base import EvalModelTemplate @@ -36,79 +36,6 @@ def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_i self.trainer_state = trainer.state -def test_state_decorator_nothing_passed(tmpdir): - """ Test that state is not changed if nothing is passed to a decorator""" - - @trainer_state() - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.INITIALIZING - assert trainer.state == TrainerState.INITIALIZING - - -def test_state_decorator_entering_only(tmpdir): - """ Tests that state is set to entering inside a run function and restored to the previous value after. """ - - @trainer_state(entering=TrainerState.RUNNING) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.RUNNING - assert trainer.state == TrainerState.INITIALIZING - - -def test_state_decorator_exiting_only(tmpdir): - """ Tests that state is not changed inside a run function and set to `exiting` after. 
""" - - @trainer_state(exiting=TrainerState.FINISHED) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.INITIALIZING - assert trainer.state == TrainerState.FINISHED - - -def test_state_decorator_entering_and_exiting(tmpdir): - """ Tests that state is set to `entering` inside a run function and set ot `exiting` after. """ - - @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.RUNNING - assert trainer.state == TrainerState.FINISHED - - -def test_state_decorator_interrupt(tmpdir): - """ Tests that state remains `INTERRUPTED` is its set in run function. """ - - @trainer_state(exiting=TrainerState.FINISHED) - def test_method(self): - self._state = TrainerState.INTERRUPTED - - trainer = Trainer(default_root_dir=tmpdir) - - test_method(trainer) - assert trainer.state == TrainerState.INTERRUPTED - - def test_initialize_state(tmpdir): """ Tests that state is INITIALIZE after Trainer creation """ trainer = Trainer(default_root_dir=tmpdir) @@ -133,7 +60,7 @@ def test_running_state_during_fit(tmpdir, extra_params): trainer.fit(model) - assert snapshot_callback.trainer_state == TrainerState.RUNNING + assert snapshot_callback.trainer_state.running() @pytest.mark.parametrize( @@ -170,7 +97,7 @@ def test_running_state_during_test(tmpdir): trainer.test(model) - assert snapshot_callback.trainer_state == TrainerState.RUNNING + assert snapshot_callback.trainer_state.running() def test_finished_state_after_test(tmpdir): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d48e86e90c535..57f93677a532b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -450,7 +450,7 @@ def mock_save_function(filepath, *args): ) checkpoint_callback.save_function = mock_save_function trainer = Trainer() - trainer.fitting = True + trainer.state = TrainerState.FITTING # emulate callback's calls during the training for i, loss in enumerate(losses): From dec84ec986b0fc2d07d934e2c0fe7c66f3988d59 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 19:42:01 +0100 Subject: [PATCH 13/34] Fix bugs --- pytorch_lightning/core/lightning.py | 4 -- pytorch_lightning/overrides/base.py | 13 ++-- .../plugins/training_type/rpc_sequential.py | 9 ++- .../logger_connector/epoch_result_store.py | 6 +- .../logger_connector/logger_connector.py | 6 +- pytorch_lightning/trainer/training_loop.py | 1 - tests/overrides/test_data_parallel.py | 68 ++++++++++++------- 7 files changed, 59 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c4d63cff4637b..669411f87797f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -171,10 +171,6 @@ def automatic_optimization(self) -> bool: """ return self._automatic_optimization - @property - def running_stage(self) -> Optional["RunningStage"]: - return self.trainer._running_stage if self.trainer else None - @automatic_optimization.setter def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index bf7be4ba774b2..2e9c37e723622 100644 --- 
a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -18,7 +18,6 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.warnings import WarningCache @@ -43,9 +42,9 @@ def __init__(self, pl_module: LightningModule): self.module = pl_module def forward(self, *inputs, **kwargs): - running_stage = self.module.running_stage + trainer = self.module.trainer - if running_stage is RunningStage.TRAINING: + if trainer.training: output = self.module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as @@ -53,18 +52,18 @@ def forward(self, *inputs, **kwargs): # `require_backward_grad_sync` will be reset in the # ddp_plugin ``post_training_step`` hook if not self.module.automatic_optimization: - self.module.trainer.model.require_backward_grad_sync = False + trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") - elif running_stage is RunningStage.TESTING: + elif trainer.testing: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif running_stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING): + elif trainer.sanity_checking or trainer.validating: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif running_stage is RunningStage.PREDICTING: + elif trainer.predicting: output = self.module.predict(*inputs, **kwargs) warn_if_output_is_none(output, "predict") diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index df61006a24406..a05f95bd36122 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -208,8 +208,7 @@ def _skip_init_connections(self): Returns: Whether to skip initialization """ - # TODO: should this use trainer.fitting? 
same for other occurrences in this file - return torch_distrib.is_initialized() and self.lightning_module.running_stage is RunningStage.TESTING + return torch_distrib.is_initialized() and self.lightning_module.trainer.state is not TrainerState.FITTING def init_model_parallel_groups(self): num_model_parallel = 1 # TODO currently no support for vertical model parallel @@ -232,7 +231,7 @@ def _infer_check_num_gpus(self): return self.world_size def handle_transferred_pipe_module(self) -> None: - if not self.lightning_module.running_stage == RunningStage.TESTING: + if self.lightning_module.trainer.state is TrainerState.FITTING: torch_distrib.barrier() # Ensure we await main process initialization # Add trainer/configure_optimizers to the pipe model for access in all worker processes rpc_pipe.PipeModel.trainer = self.lightning_module.trainer @@ -244,7 +243,7 @@ def init_pipe_module(self) -> None: # Create pipe_module model = self.lightning_module self._find_and_init_pipe_module(model) - if not self.lightning_module.running_stage == RunningStage.TESTING: + if self.lightning_module.trainer.state is TrainerState.FITTING: torch_distrib.barrier() # Ensure we join main process initialization model.sequential_module.foreach_worker(register_optimizers, include_self=True) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index e35beb56310a4..aded4a532d837 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -17,7 +17,7 @@ import torch from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -339,7 +339,9 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics.update(epoch_log_metrics) callback_metrics.update(forked_metrics) - if not is_train and self.trainer.evaluating: + # TODO(carmocca): when we implement flushing the logger connector metrics after + # the trainer.state changes, this should check trainer.evaluating instead + if not is_train and self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): logger_connector.evaluation_callback_metrics.update(callback_metrics) # update callback_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 4e70d683f5001..c0481365a5eb2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -337,7 +337,7 @@ def _track_callback_metrics(self, eval_results): flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.evaluating: + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks @@ -352,7 +352,7 @@ def _track_callback_metrics(self, eval_results): flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.evaluating: + if self.trainer.state in 
(TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(flat) def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): @@ -370,7 +370,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric callback_metrics.update(log_metrics) callback_metrics.update(prog_bar_metrics) self.trainer.logger_connector.callback_metrics.update(callback_metrics) - if self.trainer.evaluating: + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b8114972ec5b3..4f56127dbe8c9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -148,7 +148,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - self.trainer.fitting = False def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 90bb6fac88457..5d948e5c158d8 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -19,7 +19,13 @@ LightningParallelModule, LightningDistributedModule, ]) -def test_lightning_wrapper_module_methods(wrapper_class): +@pytest.mark.parametrize("stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), + ("predicting", "predict"), +]) +def test_lightning_wrapper_module_methods(wrapper_class, stage): """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """ pl_module = MagicMock() wrapped_module = wrapper_class(pl_module) @@ -27,52 +33,62 @@ def test_lightning_wrapper_module_methods(wrapper_class): batch = torch.rand(5) batch_idx = 3 - pl_module.running_stage = RunningStage.TRAINING - wrapped_module(batch, batch_idx) - pl_module.training_step.assert_called_with(batch, batch_idx) + prop, step = stage + pl_module.trainer.sanity_checking = False + for p in ("training", "testing", "validating", "predicting"): + setattr(pl_module.trainer, p, p == prop) - pl_module.running_stage = RunningStage.TESTING wrapped_module(batch, batch_idx) - pl_module.test_step.assert_called_with(batch, batch_idx) - pl_module.running_stage = RunningStage.EVALUATING - wrapped_module(batch, batch_idx) - pl_module.validation_step.assert_called_with(batch, batch_idx) - - pl_module.running_stage = RunningStage.PREDICTING - wrapped_module(batch) - pl_module.predict.assert_called_with(batch) + getattr(pl_module, step).assert_called_with(batch, batch_idx) @pytest.mark.parametrize("wrapper_class", [ LightningParallelModule, LightningDistributedModule, ]) -def test_lightning_wrapper_module_warn_none_output(wrapper_class): +@pytest.mark.parametrize("stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), +]) +def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage): """ Test that the LightningWrapper module warns about forgotten return statement. 
""" warning_cache.clear() pl_module = MagicMock() + + prop, step = stage + pl_module.trainer.sanity_checking = False + for p in ("training", "testing", "validating", "predicting"): + setattr(pl_module.trainer, p, p == prop) + wrapped_module = wrapper_class(pl_module) - pl_module.training_step.return_value = None - pl_module.validation_step.return_value = None - pl_module.test_step.return_value = None + getattr(pl_module, step).return_value = None - with pytest.warns(UserWarning, match="Your training_step returned None"): - pl_module.running_stage = RunningStage.TRAINING + with pytest.warns(UserWarning, match=f"Your {step} returned None"): wrapped_module() - with pytest.warns(UserWarning, match="Your test_step returned None"): - pl_module.running_stage = RunningStage.TESTING - wrapped_module() - with pytest.warns(UserWarning, match="Your validation_step returned None"): - pl_module.running_stage = RunningStage.EVALUATING - wrapped_module() +@pytest.mark.parametrize("wrapper_class", [ + LightningParallelModule, + LightningDistributedModule, +]) +def test_lightning_wrapper_module_no_warn(wrapper_class): + warning_cache.clear() + pl_module = MagicMock() + + pl_module.trainer.sanity_checking = False + pl_module.trainer.training = False + pl_module.trainer.testing = False + pl_module.trainer.validating = False + pl_module.trainer.predicting = False + + wrapped_module = wrapper_class(pl_module) with pytest.warns(None) as record: - pl_module.running_stage = None wrapped_module() + pl_module.assert_called() assert not record From 0b21f4d05a03b382d327de5fa1d301fa32cffc8e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 21:19:57 +0100 Subject: [PATCH 14/34] Fix tests --- .../plugins/training_type/deepspeed.py | 5 +- tests/trainer/test_states.py | 132 ++++++------------ 2 files changed, 47 insertions(+), 90 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 75e5bf74be643..82bb7c53dbfe1 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -213,7 +213,7 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if self.lightning_module.trainer.training: + if hasattr(self.lightning_module, "trainer") and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: self._initialize_deepspeed_inference(model) @@ -249,8 +249,7 @@ def _initialize_deepspeed_train(self, model): ) # set optimizer for save/load, but deepspeed manages the specific optimizer logic - trainer = self.lightning_module.trainer - trainer.optimizers = [optimizer] + self.lightning_module.trainer.optimizers = [optimizer] self.model = model def _initialize_deepspeed_inference(self, model): diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 34f766b1ab21e..1d6e4c295261e 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -15,25 +15,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.base import EvalModelTemplate - - -class StateSnapshotCallback(Callback): - """ Allows to shapshot the state inside a particular trainer method. 
""" - - def __init__(self, snapshot_method: str): - super().__init__() - assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] - self.snapshot_method = snapshot_method - self.trainer_state = None - - def on_batch_start(self, trainer, pl_module): - if self.snapshot_method == 'on_batch_start': - self.trainer_state = trainer.state - - def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - if self.snapshot_method == 'on_test_batch_start': - self.trainer_state = trainer.state +from tests.helpers import BoringModel def test_initialize_state(tmpdir): @@ -48,71 +30,53 @@ def test_initialize_state(tmpdir): pytest.param(dict(max_steps=1), id='Single-Step'), ] ) -def test_running_state_during_fit(tmpdir, extra_params): - """ Tests that state is set to RUNNING during fit """ - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start') - - trainer = Trainer(callbacks=[snapshot_callback], default_root_dir=tmpdir, **extra_params) - - trainer.fit(model) - - assert snapshot_callback.trainer_state.running() - - -@pytest.mark.parametrize( - "extra_params", [ - pytest.param(dict(fast_dev_run=True), id='Fast-Run'), - pytest.param(dict(max_steps=1), id='Single-Step'), - ] -) -def test_finished_state_after_fit(tmpdir, extra_params): - """ Tests that state is FINISHED after fit """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - trainer = Trainer(default_root_dir=tmpdir, **extra_params) +def test_trainer_state_while_running(tmpdir, extra_params): + trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True) + fdr = trainer.fast_dev_run + class TestModel(BoringModel): + def __init__(self, expected_state): + super().__init__() + self.expected_state = expected_state + self.called = set() + self.lr = 0.1 + + def on_batch_start(self, *_): + assert self.trainer.state == self.expected_state + + def on_train_batch_start(self, *_): + self.called.add("train") + assert self.trainer.training + + def on_sanity_check_start(self, *_): + self.called.add("sanity") + assert self.trainer.sanity_checking + + def on_validation_batch_start(self, *_): + self.called.add("validation") + assert self.trainer.validating or self.trainer.sanity_checking + + def on_test_batch_start(self, *_): + self.called.add("test") + assert self.trainer.testing + + model = TestModel(TrainerState.TUNING) + trainer.tune(model) + if fdr: + assert not model.called + else: + assert model.called == {'train', 'validation'} + assert trainer.state is TrainerState.FINISHED + + model = TestModel(TrainerState.FITTING) trainer.fit(model) + assert model.called == {'train', 'validation'} if fdr else {'train', 'sanity', 'validation'} + assert trainer.state is TrainerState.FINISHED - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - - -def test_running_state_during_test(tmpdir): - """ Tests that state is set to RUNNING during test """ - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start') - - trainer = Trainer( - callbacks=[snapshot_callback], - default_root_dir=tmpdir, - fast_dev_run=True, - ) - + model = TestModel(TrainerState.TESTING) trainer.test(model) - - assert snapshot_callback.trainer_state.running() - - -def test_finished_state_after_test(tmpdir): - """ Tests that state is FINISHED 
after fit """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - ) - - trainer.test(model) - - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert model.called == {'test'} + assert trainer.state is TrainerState.FINISHED @pytest.mark.parametrize( @@ -123,19 +87,13 @@ def test_finished_state_after_test(tmpdir): ) def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) + model = BoringModel() class InterruptCallback(Callback): - - def __init__(self): - super().__init__() - def on_batch_start(self, trainer, pl_module): raise KeyboardInterrupt trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmpdir, **extra_params) trainer.fit(model) - assert trainer.state == TrainerState.INTERRUPTED From 8cdac8e2b7a1344fb7861b105f66a92633887193 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 22:14:24 +0100 Subject: [PATCH 15/34] Update CHANGELOG. Add deprecation warning. Fix tests --- CHANGELOG.md | 15 +++++ .../plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/trainer/deprecated_api.py | 3 +- pytorch_lightning/trainer/states.py | 6 +- tests/deprecated_api/test_remove_1-5.py | 6 ++ tests/plugins/test_deepspeed_plugin.py | 1 + .../trainer/flags/test_val_check_interval.py | 66 ++----------------- 7 files changed, 32 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1d4b1d6c983..80e0fcc786d68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) +- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Changed +- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `trainer.evaluate` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Deprecated +- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 82bb7c53dbfe1..7250cf191163b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -213,7 +213,7 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if hasattr(self.lightning_module, "trainer") and self.lightning_module.trainer.training: + if self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: self._initialize_deepspeed_inference(model) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index ff3bc2a876629..20d323b644286 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -157,6 +157,7 @@ def get_model(self) -> LightningModule: @property def running_sanity_check(self) -> bool: rank_zero_warn( - "TODO", DeprecationWarning + "The use of `Trainer.running_sanity_check` is deprecated in favor of `Trainer.sanity_checking`" + " and will be removed in v1.5", DeprecationWarning ) return self.sanity_checking diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 86c7a544ee173..760069bb55be1 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import wraps -from typing import Callable, Optional - -import pytorch_lightning from pytorch_lightning.utilities import LightningEnum @@ -46,7 +42,7 @@ def stopped(self) -> bool: def running(self) -> bool: return self in (self.FITTING, self.VALIDATING, self.TESTING, self.PREDICTING, self.TUNING) -# + class RunningStage(LightningEnum): """Current running stage. 
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index e87fb5c2ebbb2..80698431e0070 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -54,3 +54,9 @@ def on_save_checkpoint(self, *args): trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] with no_warning_call(DeprecationWarning): trainer.save_checkpoint(filepath) + + +def test_v1_5_0_running_sanity_check(): + trainer = Trainer() + with pytest.deprecated_call(match='deprecated in favor of `Trainer.sanity_checking`'): + assert not trainer.running_sanity_check diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 903855fd2c0eb..05e9238fae524 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -38,6 +38,7 @@ def test_deepspeed_lightning_module_precision(tmpdir): """ model = BoringModel() + model.trainer = Trainer() module = LightningDeepSpeedModule(model, precision=16) module.cuda().half() diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index af9ae06e8445b..7f3e9f6287cd8 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -18,7 +18,8 @@ @pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_1(tmpdir, max_epochs): +@pytest.mark.parametrize('denominator', [1, 3, 4]) +def test_val_check_interval(tmpdir, max_epochs, denominator): class TestModel(BoringModel): @@ -31,71 +32,16 @@ def on_train_epoch_start(self) -> None: self.train_epoch_calls += 1 def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: + if not self.trainer.sanity_checking: self.val_epoch_calls += 1 model = TestModel() trainer = Trainer( max_epochs=max_epochs, - val_check_interval=1.0, + val_check_interval=1 / denominator, logger=False, ) trainer.fit(model) - assert model.val_epoch_calls == max_epochs - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_quarter(tmpdir, max_epochs): - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.25, - logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 4 - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_third(tmpdir, max_epochs): - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.33, - logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 3 + assert model.train_epoch_calls == max_epochs + assert model.val_epoch_calls == max_epochs * denominator From 73916f481bba5cff881b647658c7b5d567c9e863 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 22:17:10 +0100 Subject: [PATCH 16/34] 
Unused imports --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/trainer/deprecated_api.py | 1 - pytorch_lightning/trainer/training_loop.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 669411f87797f..62133a31a9f0e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args if TYPE_CHECKING: - from pytorch_lightning.trainer.states import RunningStage + pass class LightningModule( diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 20d323b644286..2693012ec23f1 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -14,7 +14,6 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4f56127dbe8c9..d3505feebcd15 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info From 6c62eecf283bd6c2192c38abdb643cea788ea7d1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 22:34:46 +0100 Subject: [PATCH 17/34] Optional trainer --- pytorch_lightning/core/lightning.py | 5 +---- pytorch_lightning/overrides/base.py | 8 ++++---- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 1 - 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 62133a31a9f0e..e89234a4e37ae 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -24,7 +24,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -44,9 +44,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args -if TYPE_CHECKING: - pass - class LightningModule( ABC, diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 2e9c37e723622..170cdc4600bb4 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -44,7 +44,7 @@ def __init__(self, pl_module: LightningModule): def forward(self, 
*inputs, **kwargs): trainer = self.module.trainer - if trainer.training: + if trainer and trainer.training: output = self.module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as @@ -55,15 +55,15 @@ def forward(self, *inputs, **kwargs): trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") - elif trainer.testing: + elif trainer and trainer.testing: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif trainer.sanity_checking or trainer.validating: + elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif trainer.predicting: + elif trainer and trainer.predicting: output = self.module.predict(*inputs, **kwargs) warn_if_output_is_none(output, "predict") diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 7250cf191163b..f231186c4d85c 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -213,7 +213,7 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if self.lightning_module.trainer.training: + if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: self._initialize_deepspeed_inference(model) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 05e9238fae524..903855fd2c0eb 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -38,7 +38,6 @@ def test_deepspeed_lightning_module_precision(tmpdir): """ model = BoringModel() - model.trainer = Trainer() module = LightningDeepSpeedModule(model, precision=16) module.cuda().half() From 0a211a26e6d9aefa0a082b6d91fe2b2a03bb40e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 23:07:18 +0100 Subject: [PATCH 18/34] More deprecation. More refactoring --- CHANGELOG.md | 3 +++ pytorch_lightning/accelerators/accelerator.py | 4 +-- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 2 +- .../training_type/training_type_plugin.py | 4 +-- pytorch_lightning/trainer/deprecated_api.py | 11 +++++++- pytorch_lightning/trainer/trainer.py | 26 ++++++++----------- tests/deprecated_api/test_remove_1-5.py | 8 +++++- 8 files changed, 37 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80e0fcc786d68..10df4d52de505 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Deprecated `trainer.tested_ckpt_path` in favor of `trainer.evaluated_ckpt_path` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8fb6f88a322c2..a7c6abead41c8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -82,8 +82,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: def start_training(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_training(trainer) - def start_testing(self, trainer: 'Trainer') -> None: - self.training_type_plugin.start_testing(trainer) + def start_evaluating(self, trainer: 'Trainer') -> None: + self.training_type_plugin.start_evaluating(trainer) def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index d07006ef4fb6d..7a9fa58c282e0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -157,7 +157,7 @@ def new_process(self, process_idx, trainer, mp_queue): self.barrier() - results = trainer.train_or_test_or_predict() + results = trainer.run_stage() # persist info in ddp_spawn self.transfer_distrib_spawn_state_on_fit_end(results) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 516b96b082272..d55beadc605c7 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -98,7 +98,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: trainer.save_checkpoint = self.save_checkpoint self.barrier() - results = trainer.train_or_test_or_predict() + results = trainer.run_stage() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 77fae2746c402..20d36cd11faf5 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -117,9 +117,9 @@ def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop self._results = trainer.run_train() - def start_testing(self, trainer: 'Trainer') -> None: + def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_test() + self._results = trainer.run_evaluate() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 2693012ec23f1..3ddcabd7e570d 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -137,6 +137,7 @@ class DeprecatedTrainerAttributes: accelerator: Accelerator 
lightning_module = LightningModule sanity_checking: bool + evaluated_ckpt_path: str @property def accelerator_backend(self) -> Accelerator: @@ -156,7 +157,15 @@ def get_model(self) -> LightningModule: @property def running_sanity_check(self) -> bool: rank_zero_warn( - "The use of `Trainer.running_sanity_check` is deprecated in favor of `Trainer.sanity_checking`" + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" " and will be removed in v1.5", DeprecationWarning ) return self.sanity_checking + + @property + def tested_ckpt_path(self): + rank_zero_warn( + '`Trainer.tested_ckpt_path` has been renamed to `Trainer.evaluated_ckpt_path`' + ' in v1.1 and will be removed in v1.3.', DeprecationWarning + ) + return self.evaluated_ckpt_path diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1f75eb09510ee..b0c8e329692ea 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -509,25 +509,21 @@ def post_dispatch(self): self.accelerator.teardown() def dispatch(self): - if self.testing: - self.accelerator.start_testing(self) - + if self.evaluating: + self.accelerator.start_evaluating(self) elif self.predicting: self.accelerator.start_predicting(self) - else: self.accelerator.start_training(self) - def train_or_test_or_predict(self): - if self.testing: - results = self.run_test() - + def run_stage(self): + results = None + if self.evaluating: + results = self.run_evaluate() elif self.predicting: results = self.run_predict() - else: - results = self.run_train() - + self.run_train() return results def _pre_training_routine(self): @@ -562,7 +558,7 @@ def _pre_training_routine(self): if self.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() - def run_train(self): + def run_train(self) -> None: self._pre_training_routine() @@ -741,13 +737,13 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_test(self): + def run_evaluate(self): if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - assert self.testing + assert self.evaluating - with self.profiler.profile("run_test_evaluation"): + with self.profiler.profile(f"run_{self._running_stage}_evaluation"): eval_loop_results, _ = self.run_evaluation() if len(eval_loop_results) == 0: diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 80698431e0070..60cf3fc6770aa 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -58,5 +58,11 @@ def on_save_checkpoint(self, *args): def test_v1_5_0_running_sanity_check(): trainer = Trainer() - with pytest.deprecated_call(match='deprecated in favor of `Trainer.sanity_checking`'): + with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): assert not trainer.running_sanity_check + + +def test_v1_5_0_tested_ckpt_path(): + trainer = Trainer() + with pytest.deprecated_call(match='has been renamed to `Trainer.evaluated_ckpt_path`'): + assert not trainer.tested_ckpt_path From 6c6752c65133a8b70cec1ff5165ba9b2686dea63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 25 Feb 2021 23:09:09 +0100 Subject: [PATCH 19/34] Correct version --- pytorch_lightning/trainer/deprecated_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 
3ddcabd7e570d..5d9809fdd335d 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -158,7 +158,7 @@ def get_model(self) -> LightningModule: def running_sanity_check(self) -> bool: rank_zero_warn( "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" - " and will be removed in v1.5", DeprecationWarning + " and will be removed in v1.5.", DeprecationWarning ) return self.sanity_checking @@ -166,6 +166,6 @@ def running_sanity_check(self) -> bool: def tested_ckpt_path(self): rank_zero_warn( '`Trainer.tested_ckpt_path` has been renamed to `Trainer.evaluated_ckpt_path`' - ' in v1.1 and will be removed in v1.3.', DeprecationWarning + ' and will be removed in v1.5.', DeprecationWarning ) return self.evaluated_ckpt_path From 34ae4180e5a3036d203c9ca635208a96ac27f743 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Feb 2021 01:58:26 +0100 Subject: [PATCH 20/34] Use properties --- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/states.py | 3 +++ pytorch_lightning/trainer/trainer.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index d0a8e2328a873..de70d5091a607 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -478,7 +478,7 @@ def validating(self, val: bool) -> None: @property def evaluating(self) -> bool: - return self._running_stage and self._running_stage.evaluating() + return self._running_stage and self._running_stage.evaluating @property def sanity_checking(self) -> bool: diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 760069bb55be1..d0c2ded659f67 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -36,9 +36,11 @@ class TrainerState(LightningEnum): FINISHED = 'FINISHED' INTERRUPTED = 'INTERRUPTED' + @property def stopped(self) -> bool: return self in (self.FINISHED, self.INTERRUPTED) + @property def running(self) -> bool: return self in (self.FITTING, self.VALIDATING, self.TESTING, self.PREDICTING, self.TUNING) @@ -61,5 +63,6 @@ class RunningStage(LightningEnum): PREDICTING = 'predict' TUNING = 'tune' + @property def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b0c8e329692ea..ee3d98ddba0f9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -418,7 +418,7 @@ def fit( """ # we reuse fit for other functions. When already set, it shouldn't be modified. 
- if not self.state.running(): + if not self.state.running: self.state = TrainerState.FITTING if self._running_stage is None: self.training = True @@ -880,7 +880,7 @@ def test( self.teardown('test') - assert self.state.stopped() + assert self.state.stopped self.testing = False return results @@ -994,7 +994,7 @@ def predict( self.model = model results = self.fit(model) - assert self.state.stopped() + assert self.state.stopped self.predicting = False return results @@ -1026,7 +1026,7 @@ def tune( self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) - assert self.state.stopped() + assert self.state.stopped self.tuning = False def call_setup_hook(self, model): From 24f1c1e2a7cc2bed00183a38301227b025639084 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 1 Mar 2021 12:00:20 +0100 Subject: [PATCH 21/34] Address comments --- CHANGELOG.md | 2 +- .../plugins/training_type/tpu_spawn.py | 1 - .../trainer/configuration_validator.py | 5 +++-- pytorch_lightning/trainer/deprecated_api.py | 9 -------- pytorch_lightning/trainer/evaluation_loop.py | 3 ++- pytorch_lightning/trainer/trainer.py | 22 +++++++++---------- tests/deprecated_api/test_remove_1-5.py | 6 ----- tests/trainer/test_trainer.py | 6 ++--- 8 files changed, 20 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10df4d52de505..4be9f3a62e297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Changed `trainer.evaluate` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) ### Deprecated diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d55beadc605c7..a5d8898c57f22 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -230,7 +230,6 @@ def post_dispatch(self) -> None: def __load_weights_on_main_process(self) -> None: model = self.lightning_module - assert hasattr(model, "trainer") # load weights if not interrupted if on_colab_kaggle() and model.trainer.state is TrainerState.FITTING: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 8b7ce13004636..1bf38048ee159 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -82,9 +82,10 @@ def __verify_train_loop_configuration(self, model): ) def __verify_eval_loop_configuration(self, model): - stage = self.trainer._running_stage + stage = "val" if self.trainer.validating else "test" + + loader_name = f'{stage}_dataloader' step_name = f'{stage}_step' - loader_name = 'val_dataloader' if self.trainer.validating else f'{stage}_dataloader' has_loader = is_overridden(loader_name, model) has_step = is_overridden(step_name, model) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 5d9809fdd335d..70db8b36814ca 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -137,7 +137,6 @@ class DeprecatedTrainerAttributes: accelerator: 
Accelerator lightning_module = LightningModule sanity_checking: bool - evaluated_ckpt_path: str @property def accelerator_backend(self) -> Accelerator: @@ -161,11 +160,3 @@ def running_sanity_check(self) -> bool: " and will be removed in v1.5.", DeprecationWarning ) return self.sanity_checking - - @property - def tested_ckpt_path(self): - rank_zero_warn( - '`Trainer.tested_ckpt_path` has been renamed to `Trainer.evaluated_ckpt_path`' - ' and will be removed in v1.5.', DeprecationWarning - ) - return self.evaluated_ckpt_path diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index b45d58c836648..9979a5cda87bd 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -39,7 +39,8 @@ def on_trainer_init(self): self.trainer.val_dataloaders = None # .validate() and .test() set this when they load a checkpoint - self.trainer.evaluated_ckpt_path = None + self.trainer.validated_ckpt_path = None + self.trainer.tested_ckpt_path = None # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ee3d98ddba0f9..09377e9fbacc8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -875,7 +875,7 @@ def test( results = ( self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else - self.__evaluate_using_best_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) + self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) self.teardown('test') @@ -885,7 +885,7 @@ def test( return results - def __evaluate_using_best_weights( + def __evaluate_using_weights( self, model, ckpt_path: Optional[str] = None, @@ -894,7 +894,7 @@ def __evaluate_using_best_weights( # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( - 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' + 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights @@ -905,8 +905,8 @@ def __evaluate_using_best_weights( if len(ckpt_path) == 0: rank_zero_warn( - f'.test() found no path for the best weights, {ckpt_path}. Please ' - f'specify a path for a checkpoint .test(ckpt_path=PATH)' + f'`.test()` found no path for the best weights, {ckpt_path}. 
Please'
+                ' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
             )
             return {}
         if not self._device_type == DeviceType.TPU:
@@ -919,8 +919,12 @@ def __evaluate_using_best_weights(
         if dataloaders is not None:
             self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
 
+        if self.validating:
+            self.validated_ckpt_path = ckpt_path
+        else:
+            self.tested_ckpt_path = ckpt_path
+
         # run test
-        self.evaluated_ckpt_path = ckpt_path
         results = self.fit(model)
 
         # teardown
@@ -1034,11 +1038,7 @@ def call_setup_hook(self, model):
         stage_name = 'test' if self.evaluating else 'fit'
 
         if self.datamodule is not None:
-            called = {
-                'fit': self.datamodule.has_setup_fit,
-                'test': self.datamodule.has_setup_test,
-            }[stage_name]
-
+            called = getattr(self.datamodule, f'has_setup_{stage_name}')
             if not called:
                 self.datamodule.setup(stage_name)
 
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 60cf3fc6770aa..ac234e22e79a2 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -60,9 +60,3 @@ def test_v1_5_0_running_sanity_check():
     trainer = Trainer()
     with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):
         assert not trainer.running_sanity_check
-
-
-def test_v1_5_0_tested_ckpt_path():
-    trainer = Trainer()
-    with pytest.deprecated_call(match='has been renamed to `Trainer.evaluated_ckpt_path`'):
-        assert not trainer.tested_ckpt_path
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 57f93677a532b..b98e86cbe099a 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -709,12 +709,12 @@ def test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
             trainer.test(ckpt_path=ckpt_path)
         else:
             trainer.test(ckpt_path=ckpt_path)
-            assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path
+            assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and
         # use the weights from the end of training
         trainer.test(ckpt_path=ckpt_path)
-        assert trainer.evaluated_ckpt_path is None
+        assert trainer.tested_ckpt_path is None
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -726,7 +726,7 @@
             )[0].absolute()
         )
         trainer.test(ckpt_path=ckpt_path)
-        assert trainer.evaluated_ckpt_path == ckpt_path
+        assert trainer.tested_ckpt_path == ckpt_path
 
 
 def test_disabled_training(tmpdir):
From 1e5d84d39a01da30256e6a6607ba156619baf33a Mon Sep 17 00:00:00 2001
From: Carlos Mocholi 
Date: Wed, 3 Mar 2021 04:02:59 +0100
Subject: [PATCH 22/34] flake8

---
 pytorch_lightning/core/lightning.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 69a5f3f9dae55..8bc73745ae6bb 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -43,9 +43,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 
-if TYPE_CHECKING:
-    from pytorch_lightning.trainer.states import RunningStage
-
 log = logging.getLogger(__name__)
 
 
@@ -69,7 +66,6 @@ class LightningModule(
         "on_gpu",
         "current_epoch",
         "global_step",
-        "running_stage",
         "global_rank",
         "local_rank",
         "logger",
From dccf6034e980971118771501e3b6fedc52348195 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi 
Date: Wed, 3 Mar 2021 04:22:06 +0100 Subject: [PATCH 23/34] Missed renamings --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/horovod.py | 4 ++-- pytorch_lightning/plugins/training_type/rpc_sequential.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- pytorch_lightning/trainer/trainer.py | 8 ++++---- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b415267e8c511..b72d679f9a637 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -110,7 +110,7 @@ def start_training(self, trainer): # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] - def start_testing(self, trainer): + def start_evaluating(self, trainer): mp.spawn(self.new_process, **self.mp_spawn_kwargs) def start_predicting(self, trainer): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8fe52190fd7bb..103a4f4af454a 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -101,9 +101,9 @@ def start_training(self, trainer): # Make sure all workers have finished training before returning to the user hvd.join() - def start_testing(self, trainer): + def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_test() + self._results = trainer.evaluate() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index a05f95bd36122..d2af09f4c1038 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -333,9 +333,9 @@ def start_training(self, trainer) -> None: if self.main_rpc_process: super().start_training(trainer) - def start_testing(self, trainer) -> None: + def start_evaluating(self, trainer) -> None: if self.main_rpc_process: - super().start_testing(trainer) + super().start_evaluating(trainer) class LightningPipeModule(nn.Module): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index f64e079433924..081f3c38a4abb 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -283,7 +283,7 @@ def start_training(self, trainer) -> None: self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) - def start_testing(self, trainer) -> None: + def start_evaluating(self, trainer) -> None: self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e2c17317cf576..01d590a94828e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -456,16 +456,16 @@ def fit( # | || # trainer.dispatch || LIGHTNING # | || - # start_training or start_testing or start_predicting call || FLOW + # start_training or start_evaluating or start_predicting call || FLOW # from `accelerator` || # | || DIRECTION - # run_train or run_test or run_predict call || + # run_train or run_evaluate or run_predict call || # 
from `trainer` || # | || # results \/ # This is used to guide readers to the core loops: train, test, predict. # `run_predict` is the simplest to understand, use `Go to Definition` to read it :) - # Search for `start_training` or `start_testing` or `start_predicting` in + # Search for `start_training` or `start_evaluating` or `start_predicting` in # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions. # ---------------------------- @@ -477,7 +477,7 @@ def fit( # plugin will setup fitting (e.g. ddp will launch child processes) self.pre_dispatch() - # dispatch `start_training` or `start_testing` or `start_predicting` + # dispatch `start_training` or `start_evaluating` or `start_predicting` self.dispatch() # plugin will finalized fitting (e.g. ddp_spawn will load trained model) From 56b4494ca7c347d85499f9245c4320be84b5cceb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 3 Mar 2021 04:22:58 +0100 Subject: [PATCH 24/34] Typo --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 103a4f4af454a..2fe3906cb01d0 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -103,7 +103,7 @@ def start_training(self, trainer): def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.evaluate() + self._results = trainer.run_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() From 94e93ad38d0f99a36c6d42040b1e44eaffe1a42c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 3 Mar 2021 13:33:20 +0100 Subject: [PATCH 25/34] is -> == It is recommended to use for Enums since they are singletons, however, since the LightningEnum subclasses str, it's not a good idea in case a user sets the state/stage with a str --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- .../plugins/training_type/ddp_spawn.py | 4 ++-- .../plugins/training_type/rpc_sequential.py | 6 +++--- pytorch_lightning/plugins/training_type/sharded.py | 2 +- .../plugins/training_type/sharded_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 6 +++--- .../logger_connector/epoch_result_store.py | 2 +- .../logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 14 +++++++------- 10 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ffbb3b62c17be..38ccce648502a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -138,7 +138,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): def on_validation_end(self, trainer, pl_module): from pytorch_lightning.trainer.states import TrainerState - if trainer.state is not TrainerState.FITTING or trainer.sanity_checking: + if trainer.state != TrainerState.FITTING or trainer.sanity_checking: return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ce021be28558a..9da0e13f203db 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -216,7 +216,7 @@ def save_checkpoint(self, trainer, pl_module): from 
pytorch_lightning.trainer.states import TrainerState if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state is not TrainerState.FITTING # don't save anything during non-fit + or trainer.state != TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b72d679f9a637..d8aafb8abdfc1 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -222,7 +222,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? if ( - self.lightning_module.trainer.state is TrainerState.FITTING + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): @@ -242,7 +242,7 @@ def __recover_child_process_weights(self, best_path, last_path): # todo, pass also best score # load last weights - if last_path is not None and self.lightning_module.trainer.state is TrainerState.FITTING: + if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index d2af09f4c1038..d74ae0f003864 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -208,7 +208,7 @@ def _skip_init_connections(self): Returns: Whether to skip initialization """ - return torch_distrib.is_initialized() and self.lightning_module.trainer.state is not TrainerState.FITTING + return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TainerState.FITTING def init_model_parallel_groups(self): num_model_parallel = 1 # TODO currently no support for vertical model parallel @@ -231,7 +231,7 @@ def _infer_check_num_gpus(self): return self.world_size def handle_transferred_pipe_module(self) -> None: - if self.lightning_module.trainer.state is TrainerState.FITTING: + if self.lightning_module.trainer.state == TrainerState.FITTING: torch_distrib.barrier() # Ensure we await main process initialization # Add trainer/configure_optimizers to the pipe model for access in all worker processes rpc_pipe.PipeModel.trainer = self.lightning_module.trainer @@ -243,7 +243,7 @@ def init_pipe_module(self) -> None: # Create pipe_module model = self.lightning_module self._find_and_init_pipe_module(model) - if self.lightning_module.trainer.state is TrainerState.FITTING: + if self.lightning_module.trainer.state == TrainerState.FITTING: torch_distrib.barrier() # Ensure we join main process initialization model.sequential_module.foreach_worker(register_optimizers, include_self=True) diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index c6ea3c7a6cc84..7536ef9b1d856 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -49,7 +49,7 @@ def _reinit_optimizers_with_oss(self): trainer.convert_to_lightning_optimizers() def _wrap_optimizers(self): - if 
self.model.trainer.state is not TrainerState.FITTING: + if self.model.trainer.state != TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 0c0b79c177e46..7aadf797e160a 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -45,7 +45,7 @@ def _reinit_optimizers_with_oss(self): trainer.optimizers = optimizers def _wrap_optimizers(self): - if self.model.trainer.state is not TrainerState.FITTING: + if self.model.trainer.state != TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 081f3c38a4abb..a3ef7b4b1c89b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -141,7 +141,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): last_path = None # TODO: is there a better way than accessing trainer through model -> trainer? if ( - self.lightning_module.trainer.state is TrainerState.FITTING + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): @@ -246,7 +246,7 @@ def post_dispatch(self) -> None: # todo, pass also bets score # load last weights - if last_path and model.trainer.state is not TrainerState.FITTING: + if last_path and model.trainer.state != TrainerState.FITTING: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -259,7 +259,7 @@ def __load_weights_on_main_process(self) -> None: model = self.lightning_module # load weights if not interrupted - if on_colab_kaggle() and model.trainer.state is TrainerState.FITTING: + if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING: self.load_spawn_weights(model) self._model = model diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index f5527d3990471..b717c4cc71c6e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -309,7 +309,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics = {} batch_pbar_metrics = {} batch_log_metrics = {} - is_train = self._stage is RunningStage.TRAINING + is_train = self._stage == RunningStage.TRAINING if not self._has_batch_loop_finished: # get pbar diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index c0481365a5eb2..d8cdd577282d3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -299,7 +299,7 @@ def get_evaluate_epoch_results(self): # log results of evaluation if ( - self.trainer.state is not TrainerState.FITTING + self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index de70d5091a607..8cbd53d93f37f 100644 --- a/pytorch_lightning/trainer/properties.py +++ 
b/pytorch_lightning/trainer/properties.py @@ -175,7 +175,7 @@ def state(self, state: TrainerState) -> None: @property def interrupted(self) -> bool: - return self._state is TrainerState.INTERRUPTED + return self._state == TrainerState.INTERRUPTED @property def is_global_zero(self) -> bool: @@ -423,7 +423,7 @@ def distributed_sampler_kwargs(self) -> Optional[dict]: @property def training(self) -> bool: - return self._running_stage is RunningStage.TRAINING + return self._running_stage == RunningStage.TRAINING @training.setter def training(self, val: bool) -> None: @@ -434,7 +434,7 @@ def training(self, val: bool) -> None: @property def testing(self) -> bool: - return self._running_stage is RunningStage.TESTING + return self._running_stage == RunningStage.TESTING @testing.setter def testing(self, val: bool) -> None: @@ -445,7 +445,7 @@ def testing(self, val: bool) -> None: @property def predicting(self) -> bool: - return self._running_stage is RunningStage.PREDICTING + return self._running_stage == RunningStage.PREDICTING @predicting.setter def predicting(self, val: bool) -> None: @@ -456,7 +456,7 @@ def predicting(self, val: bool) -> None: @property def tuning(self) -> bool: - return self._running_stage is RunningStage.TUNING + return self._running_stage == RunningStage.TUNING @tuning.setter def tuning(self, val: bool) -> None: @@ -467,7 +467,7 @@ def tuning(self, val: bool) -> None: @property def validating(self) -> bool: - return self._running_stage is RunningStage.VALIDATING + return self._running_stage == RunningStage.VALIDATING @validating.setter def validating(self, val: bool) -> None: @@ -482,7 +482,7 @@ def evaluating(self) -> bool: @property def sanity_checking(self) -> bool: - return self._running_stage is RunningStage.SANITY_CHECKING + return self._running_stage == RunningStage.SANITY_CHECKING @sanity_checking.setter def sanity_checking(self, val: bool) -> None: From 6c53cdc466705ced74c6738269ae5e9aae460783 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 3 Mar 2021 13:35:10 +0100 Subject: [PATCH 26/34] Also for tests --- tests/trainer/test_states.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 1d6e4c295261e..edcf14b37a2ef 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -66,17 +66,17 @@ def on_test_batch_start(self, *_): assert not model.called else: assert model.called == {'train', 'validation'} - assert trainer.state is TrainerState.FINISHED + assert trainer.state == TrainerState.FINISHED model = TestModel(TrainerState.FITTING) trainer.fit(model) assert model.called == {'train', 'validation'} if fdr else {'train', 'sanity', 'validation'} - assert trainer.state is TrainerState.FINISHED + assert trainer.state == TrainerState.FINISHED model = TestModel(TrainerState.TESTING) trainer.test(model) assert model.called == {'test'} - assert trainer.state is TrainerState.FINISHED + assert trainer.state == TrainerState.FINISHED @pytest.mark.parametrize( From b64b46ec850b6a212b6c2ad104913f876480f264 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Mar 2021 15:36:00 +0100 Subject: [PATCH 27/34] Typo --- pytorch_lightning/plugins/training_type/rpc_sequential.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index e8b8b48eb3460..8fd75555ecd14 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ 
b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -208,7 +208,7 @@ def _skip_init_connections(self): Returns: Whether to skip initialization """ - return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TainerState.FITTING + return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING def init_model_parallel_groups(self): num_model_parallel = 1 # TODO currently no support for vertical model parallel From 7d427986a00010387978db090449d5189fe7aa88 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Mar 2021 15:57:39 +0100 Subject: [PATCH 28/34] Address @tchaton's comments --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 -- .../connectors/logger_connector/epoch_result_store.py | 11 ++++------- .../connectors/logger_connector/logger_connector.py | 4 ++-- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8efface0a9763..3dace06cbf825 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -205,7 +205,6 @@ def on_save(self, checkpoint: dict) -> dict: return checkpoint def transfer_distrib_spawn_state_on_fit_end(self, results): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -214,7 +213,6 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None - # TODO: is there a better way than accessing trainer through model -> trainer? if ( self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None @@ -229,7 +227,6 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): self.mp_queue.put(results) def __recover_child_process_weights(self, best_path, last_path): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a3ef7b4b1c89b..2118ff492ec26 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -131,7 +131,6 @@ def barrier(self, name: Optional[str] = None) -> None: rendezvous(f"pl.Trainer.{name}") def transfer_distrib_spawn_state_on_fit_end(self, results): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path if self.mp_queue is not None: @@ -139,7 +138,6 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None - # TODO: is there a better way than accessing trainer through model -> trainer? 
if ( self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index b717c4cc71c6e..12076e1fcc41b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -222,9 +222,8 @@ class EpochResultStore: ``` """ - def __init__(self, trainer, stage): + def __init__(self, trainer) -> None: self.trainer = trainer - self._stage = stage self.reset() def __getitem__(self, key: str) -> Any: @@ -309,7 +308,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics = {} batch_pbar_metrics = {} batch_log_metrics = {} - is_train = self._stage == RunningStage.TRAINING if not self._has_batch_loop_finished: # get pbar @@ -317,8 +315,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: logger_connector.add_progress_bar_metrics(batch_pbar_metrics) batch_log_metrics = self.get_latest_batch_log_metrics() - if is_train: - # Only log and add to callback epoch step during evaluation, test. + if self.trainer.state == TrainerState.FITTING: logger_connector._logged_metrics.update(batch_log_metrics) callback_metrics.update(batch_pbar_metrics) callback_metrics.update(batch_log_metrics) @@ -341,7 +338,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: # TODO(carmocca): when we implement flushing the logger connector metrics after # the trainer.state changes, this should check trainer.evaluating instead - if not is_train and self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): logger_connector.evaluation_callback_metrics.update(callback_metrics) # update callback_metrics @@ -486,4 +483,4 @@ def __call__( return result def __repr__(self): - return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" + return f"{self.__class__.__name__}(internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ac458b15b1268..2c6a0d613e648 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -40,8 +40,8 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._logged_metrics = MetricsHolder() self._progress_bar_metrics = MetricsHolder(to_float=True) self.eval_loop_results = [] - self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage} - self._cached_results[None] = EpochResultStore(trainer, None) + self._cached_results = {stage: EpochResultStore(trainer) for stage in RunningStage} + self._cached_results[None] = EpochResultStore(trainer) self._callback_hook_validator = CallbackHookNameValidator() @property From 7a3f8cdb09d35af5e12a614b105b49514af463e6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Mar 2021 16:00:11 +0100 Subject: [PATCH 29/34] PEP8 --- .../trainer/connectors/logger_connector/epoch_result_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 12076e1fcc41b..262e721ed2b6c 
100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -17,7 +17,7 @@ import torch from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType, LightningEnum From c0ef3fac1f388b1a6a13933ee29bd4e541deb0a2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Mar 2021 16:59:54 +0100 Subject: [PATCH 30/34] Correct property --- .../trainer/connectors/logger_connector/epoch_result_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 262e721ed2b6c..223216846758f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -315,7 +315,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: logger_connector.add_progress_bar_metrics(batch_pbar_metrics) batch_log_metrics = self.get_latest_batch_log_metrics() - if self.trainer.state == TrainerState.FITTING: + if self.trainer.training: logger_connector._logged_metrics.update(batch_log_metrics) callback_metrics.update(batch_pbar_metrics) callback_metrics.update(batch_log_metrics) From 63c949322c2601d0fe479ba9be37b1fe206150bf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Mar 2021 17:15:22 +0100 Subject: [PATCH 31/34] Update CHANGELOG --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc1e0920f4b2f..06772b93773dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,9 +50,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Deprecated `trainer.tested_ckpt_path` in favor of `trainer.evaluated_ckpt_path` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) From 10f7f21694d873ae16672142fafe04e7571a9c84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 6 Mar 2021 02:02:59 +0100 Subject: [PATCH 32/34] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- tests/trainer/flags/test_fast_dev_run.py | 4 ++-- tests/trainer/test_trainer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2118ff492ec26..efada181ca9a6 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -244,7 +244,7 @@ def post_dispatch(self) -> None: # todo, pass also bets score # load last weights - if last_path and model.trainer.state != TrainerState.FITTING: + if last_path and model.trainer.state == TrainerState.FITTING: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 3c679acffc191..09c5b58d363d9 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -20,8 +20,8 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - auto_scale_batch_size=tuner_alg == 'batch size scaler', - auto_lr_find=tuner_alg == 'learning rate finder', + auto_scale_batch_size=(tuner_alg == 'batch size scaler'), + auto_lr_find=(tuner_alg == 'learning rate finder'), fast_dev_run=True ) expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index fc93383807381..5012ed2321480 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -601,7 +601,7 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_checkpoint_path(tmpdir, ckpt_path, save_top_k): +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k): hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) From d1dc4c99a0232996ce65ae65bae35fe4de6645f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 6 Mar 2021 02:03:24 +0100 Subject: [PATCH 33/34] Update pytorch_lightning/trainer/trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2c14212820c0c..cc1964f07039b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -922,7 +922,7 @@ def __evaluate_using_weights( self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) if self.validating: - self.validated_ckpt_paath = ckpt_path + self.validated_ckpt_path = ckpt_path else: self.tested_ckpt_path = ckpt_path From 45a010f2a28f9cb43bf54b7110db50205c46e778 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 6 Mar 2021 02:11:22 +0100 Subject: [PATCH 34/34] Remove called sanity check --- tests/trainer/test_states.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index edcf14b37a2ef..bedaef6d1ffb8 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -32,50 +32,38 @@ def test_initialize_state(tmpdir): ) def test_trainer_state_while_running(tmpdir, extra_params): trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True) - fdr = trainer.fast_dev_run class TestModel(BoringModel): def __init__(self, expected_state): super().__init__() self.expected_state = expected_state - self.called = set() self.lr = 0.1 def on_batch_start(self, *_): assert self.trainer.state == self.expected_state def on_train_batch_start(self, *_): - self.called.add("train") assert self.trainer.training def on_sanity_check_start(self, *_): - self.called.add("sanity") assert self.trainer.sanity_checking def on_validation_batch_start(self, *_): - self.called.add("validation") assert self.trainer.validating or self.trainer.sanity_checking def on_test_batch_start(self, *_): - self.called.add("test") assert self.trainer.testing model = TestModel(TrainerState.TUNING) trainer.tune(model) - if fdr: - assert not model.called - else: - assert model.called == {'train', 'validation'} assert trainer.state == TrainerState.FINISHED model = TestModel(TrainerState.FITTING) trainer.fit(model) - assert model.called == {'train', 'validation'} if fdr else {'train', 'sanity', 'validation'} assert trainer.state == TrainerState.FINISHED model = TestModel(TrainerState.TESTING) trainer.test(model) - assert model.called == {'test'} assert trainer.state == TrainerState.FINISHED
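
A short aside on the `is` -> `==` change in patch 25/34: because `LightningEnum` mixes in `str`, a stage or state assigned as a plain string compares equal to the enum member but is never the same object, so identity checks silently fail. A minimal sketch of that behaviour follows; the `LightningEnum` stand-in and the member values are illustrative, not the library's exact definitions.

```python
from enum import Enum


class LightningEnum(str, Enum):
    """Minimal stand-in for pytorch_lightning.utilities.LightningEnum (str + Enum mixin)."""


class RunningStage(LightningEnum):
    TRAINING = 'train'
    TESTING = 'test'


stage = 'train'  # e.g. a user or legacy code assigns the stage as a plain string

# `==` falls back to string comparison, so the check still passes
assert stage == RunningStage.TRAINING

# `is` checks object identity; a plain str is never the enum member itself
assert stage is not RunningStage.TRAINING
```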
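The boolean properties touched in patch 25/34 all derive from a single `_running_stage` attribute. Below is a stripped-down sketch of that pattern, assuming `evaluating` simply means validating or testing; the names and values are illustrative, not the actual Trainer implementation.

```python
from enum import Enum


class RunningStage(str, Enum):
    """Illustrative subset of the stages used by the Trainer."""
    TRAINING = 'train'
    VALIDATING = 'validate'
    TESTING = 'test'


class Trainer:
    def __init__(self) -> None:
        self._running_stage = None

    @property
    def training(self) -> bool:
        # `==` rather than `is`, for the reason shown in the previous sketch
        return self._running_stage == RunningStage.TRAINING

    @property
    def validating(self) -> bool:
        return self._running_stage == RunningStage.VALIDATING

    @property
    def testing(self) -> bool:
        return self._running_stage == RunningStage.TESTING

    @property
    def evaluating(self) -> bool:
        # assumption: "evaluating" covers both validation and test runs
        return self.validating or self.testing


trainer = Trainer()
trainer._running_stage = RunningStage.TESTING
assert trainer.testing and trainer.evaluating and not trainer.training
```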