diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 77f30219ba8c0..8bb335f2e7847 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -52,6 +52,10 @@ def __init__(self,
     def setup(self, model):
         pass

+    def train(self):
+        self.trainer.setup_trainer(self.trainer.model)
+        return self.train_or_test()
+
     def teardown(self):
         # Ensure if necessary all processes are finished
         self.barrier()
@@ -66,6 +70,7 @@ def train_or_test(self):
         if self.trainer.testing:
             results = self.trainer.run_test()
         else:
+            self.trainer.train_loop.setup_training()
             results = self.trainer.train()
         return results
diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py
index e034b209bf34c..997a3568daf2d 100644
--- a/pytorch_lightning/accelerators/cpu_accelerator.py
+++ b/pytorch_lightning/accelerators/cpu_accelerator.py
@@ -50,16 +50,6 @@ def setup(self, model):

         self.trainer.model = model

-    def train(self):
-        model = self.trainer.model
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-        return results
-
     def _step(self, model_step: Callable, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index 68af3f579a6e8..373406589d855 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -186,9 +186,6 @@ def ddp_train(self, process_idx, mp_queue, model):

         self.ddp_plugin.on_after_setup_optimizers(self.trainer)

-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -198,8 +195,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)

-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index f0d9f2171bf48..0fde9da158c94 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -285,9 +285,6 @@ def ddp_train(self, process_idx, model):
         # allow for lr schedulers as well
         self.setup_optimizers(model)

-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -297,9 +294,8 @@ def ddp_train(self, process_idx, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)

-        # set up training routine
         self.barrier('ddp_setup')
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index e7ef38c8df3b4..f9ccaa200bbf4 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -146,9 +146,6 @@ def ddp_train(self, process_idx, mp_queue, model):
         self.ddp_plugin.on_after_setup_optimizers(self.trainer)

-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -158,8 +155,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)

-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index c25e082ee348d..bdc4631b5d017 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -177,9 +177,6 @@ def ddp_train(self, process_idx, model):

         self.ddp_plugin.on_after_setup_optimizers(self.trainer)

-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -189,8 +186,7 @@ def ddp_train(self, process_idx, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)

-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index 23783fada72f1..eb4ff24e39dd4 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -161,9 +161,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0

         self.ddp_plugin.on_after_setup_optimizers(self.trainer)

-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -173,8 +170,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)

-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py
index fc01c4686f04f..7517c774f51dd 100644
--- a/pytorch_lightning/accelerators/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/dp_accelerator.py
@@ -101,16 +101,6 @@ def __init_nvidia_apex(self, model):

         return model

-    def train(self):
-        model = self.trainer.model
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        return results
-
     def teardown(self):
         # replace the original fwd function
         self.trainer.model.forward = self.model_autocast_original_forward
diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py
index 49f21e9e34816..d65b19bbd9bb1 100644
--- a/pytorch_lightning/accelerators/gpu_accelerator.py
+++ b/pytorch_lightning/accelerators/gpu_accelerator.py
@@ -56,16 +56,6 @@ def setup(self, model):

         self.trainer.model = model

-    def train(self):
-        model = self.trainer.model
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-        return results
-
     def _step(self, model_step: Callable, args):
         args[0] = self.to_device(args[0])
diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py
index 2013d75df7b1e..6e11a13064513 100644
--- a/pytorch_lightning/accelerators/horovod_accelerator.py
+++ b/pytorch_lightning/accelerators/horovod_accelerator.py
@@ -104,8 +104,7 @@ def train(self):
                 # Synchronization will be performed explicitly following backward()
                 stack.enter_context(optimizer.skip_synchronize())

-            # set up training routine
-            self.trainer.train_loop.setup_training(self.trainer.model)
+            self.trainer.setup_trainer(self.trainer.model)

             # train or test
             results = self.train_or_test()
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 7dcfaae401ca7..286004bc0976e 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -134,8 +134,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # setup TPU training
         self.__setup_tpu_training(model, trainer)

-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)

         # train or test
         results = self.train_or_test()
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index fc9c70ba46d2e..03d46132fb177 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -13,8 +13,8 @@
 #  limitations under the License.
 import os
-from pathlib import Path
 import re
+from pathlib import Path
 from typing import Optional, Union

 import torch
@@ -44,7 +44,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False

-    def restore_weights(self, model: LightningModule) -> None:
+    def restore_weights(self) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
@@ -64,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None:
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')

         # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
+        elif self.trainer.resume_from_checkpoint is not None:
             self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

         # wait for all to catch up
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 6cf020aa65fa1..84e8a1bc68f05 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -208,9 +208,9 @@ def add_progress_bar_metrics(self, metrics):

             self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

-    def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
+    def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result):
         self._track_callback_metrics(deprecated_eval_results, using_eval_result)
-        self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)
+        self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

     def evaluation_epoch_end(self, testing):
         # reset dataloader idx
@@ -239,7 +239,7 @@ def prepare_eval_loop_results(self):
         for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
             self.add_to_eval_loop_results(dl_idx, has_been_initialized)

-    def get_evaluate_epoch_results(self, test_mode):
+    def get_evaluate_epoch_results(self):
         if not self.trainer.running_sanity_check:
             # log all the metrics as a single dict
             metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -249,7 +249,7 @@ def get_evaluate_epoch_results(self):
         self.prepare_eval_loop_results()

         # log results of test
-        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+        if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
             print('-' * 80)
             for result_idx, results in enumerate(self.eval_loop_results):
                 print(f'DATALOADER:{result_idx} TEST RESULTS')
@@ -330,7 +330,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
         if len(dataloader_result_metrics) > 0:
             self.eval_loop_results.append(dataloader_result_metrics)

-    def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
+    def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if self.trainer.running_sanity_check:
             return

@@ -350,7 +350,7 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
                 callback_metrics = result.callback_metrics

                 # in testing we don't need the callback metrics
-                if test_mode:
+                if self.trainer.testing:
                     callback_metrics = {}
             else:
                 _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 4b70917c8c43d..63f65bead2579 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -24,7 +24,6 @@ class EvaluationLoop(object):
     def __init__(self, trainer):
         self.trainer = trainer
-        self.testing = False
         self.outputs = []
         self.step_metrics = []
         self.predictions = None
@@ -52,7 +51,7 @@ def get_evaluation_dataloaders(self, max_batches):
         model = self.trainer.get_model()

         # select dataloaders
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.reset_test_dataloader(model)
             dataloaders = self.trainer.test_dataloaders
@@ -85,34 +84,34 @@ def should_skip_evaluation(self, dataloaders, max_batches):
         return False

     def on_evaluation_start(self, *args, **kwargs):
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_start', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_start', *args, **kwargs)

     def on_evaluation_model_eval(self, *args, **kwargs):
         model_ref = self.trainer.get_model()
-        if self.testing:
+        if self.trainer.testing:
             model_ref.on_test_model_eval()
         else:
             model_ref.on_validation_model_eval()

     def on_evaluation_model_train(self, *args, **kwargs):
         model_ref = self.trainer.get_model()
-        if self.testing:
+        if self.trainer.testing:
             model_ref.on_test_model_train()
         else:
             model_ref.on_validation_model_train()

     def on_evaluation_end(self, *args, **kwargs):
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_end', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_end', *args, **kwargs)

     def reload_evaluation_dataloaders(self):
         model = self.trainer.get_model()
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.reset_test_dataloader(model)
         else:
             self.trainer.reset_val_dataloader(model)
@@ -123,9 +122,6 @@ def is_using_eval_results(self):
         return using_eval_result

     def setup(self, model, max_batches, dataloaders):
-        # copy properties for forward overrides
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # bookkeeping
         self.outputs = []
         self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
@@ -138,17 +134,23 @@ def setup(self, model, max_batches, dataloaders):
         self.num_dataloaders = self._get_num_dataloaders(dataloaders)

     def on_evaluation_epoch_start(self, *args, **kwargs):
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

-    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
+    def _build_args(self, batch, batch_idx, dataloader_idx):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]

-        multiple_val_loaders = (not test_mode and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1)
-        multiple_test_loaders = (test_mode and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)
+        multiple_val_loaders = (
+            not self.trainer.testing
+            and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1
+        )
+        multiple_test_loaders = (
+            self.trainer.testing
+            and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1
+        )

         if multiple_test_loaders or multiple_val_loaders:
             args.append(dataloader_idx)
@@ -163,14 +165,14 @@ def _get_num_dataloaders(self, dataloaders):
             length = len(dataloaders[0])
         return length

-    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
+    def evaluation_step(self, batch, batch_idx, dataloader_idx):
         # configure args
-        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+        args = self._build_args(batch, batch_idx, dataloader_idx)

         model_ref = self.trainer.get_model()
         model_ref._results = Result()
         # run actual test step
-        if self.testing:
+        if self.trainer.testing:
             model_ref._current_fx_name = "test_step"
             output = self.trainer.accelerator_backend.test_step(args)
         else:
@@ -192,7 +194,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx):
         return output

     def evaluation_step_end(self, *args, **kwargs):
-        if self.testing:
+        if self.trainer.testing:
             output = self.trainer.call_hook('test_step_end', *args, **kwargs)
         else:
             output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
@@ -200,7 +202,7 @@ def evaluation_step_end(self, *args, **kwargs):

     def evaluation_epoch_end(self):
         # unset dataloder_idx in model
-        self.trainer.logger_connector.evaluation_epoch_end(self.testing)
+        self.trainer.logger_connector.evaluation_epoch_end(self.trainer.testing)

         using_eval_result = self.is_using_eval_results()

@@ -216,7 +218,7 @@ def evaluation_epoch_end(self):

     def log_epoch_metrics_on_evaluation_end(self):
         # get the final loop results
-        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing)
+        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
         return eval_loop_results

     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
@@ -230,7 +232,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):

         user_reduced = False

-        if self.testing:
+        if self.trainer.testing:
             if is_overridden('test_epoch_end', model=model):
                 if using_eval_result:
                     eval_results = self.__gather_epoch_end_eval_results(outputs)
@@ -250,7 +252,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
                     self.trainer.logger_connector.cache_logged_metrics()
         # depre warning
         if eval_results is not None and user_reduced:
-            step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
+            step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
             self.warning_cache.warn(
                 f'The {step} should not return anything as of 9.1.'
                 ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
@@ -263,7 +265,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
             eval_results = [eval_results]

         # track depreceated metrics
-        self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result, self.testing)
+        self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result)

         return eval_results

@@ -300,15 +302,15 @@ def __auto_reduce_result_objs(self, outputs):
     def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
         # set dataloader_idx to model and track batch_size
         self.trainer.logger_connector.on_evaluation_batch_start(
-            self.testing, batch, dataloader_idx, self.num_dataloaders)
+            self.trainer.testing, batch, dataloader_idx, self.num_dataloaders)

-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
         else:
             self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

     def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx):
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx)
         else:
             self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx)
@@ -319,16 +321,16 @@ def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx):
     def store_predictions(self, output, batch_idx, dataloader_idx):
         # Add step predictions to prediction collection to write later
         if output is not None:
-            do_write_predictions = isinstance(output, Result) and self.testing
+            do_write_predictions = isinstance(output, Result) and self.trainer.testing
             if do_write_predictions:
                 self.predictions.add(output.pop('predictions', None))

         # track debug metrics
-        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
+        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

     def on_evaluation_epoch_end(self, *args, **kwargs):
         # call the callback hook
-        if self.testing:
+        if self.trainer.testing:
             self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2c1867a21552d..c3ef0e507789e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,9 +15,9 @@
 """Trainer to automate the training."""
 import os
+import warnings
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Union
-import warnings

 import torch
 from torch.utils.data import DataLoader
@@ -57,7 +57,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -412,6 +412,46 @@ def __init__(
         # Callback system
         self.on_init_end()

+    def setup_trainer(self, model: LightningModule):
+        """
+        Sanity check a few things before starting actual training or testing.
+
+        Args:
+            model: The model to run sanity checks on.
+        """
+        # --------------------------
+        # Setup
+        # --------------------------
+        ref_model = self.get_model()
+
+        # set the ranks and devices
+        self.accelerator_backend.dist.rank = self.global_rank
+        self.accelerator_backend.dist.device = ref_model.device
+
+        # set local properties on the model
+        self.model_connector.copy_trainer_model_properties(model)
+
+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
+            self.scaler = self.precision_connector.backend.scaler
+
+        # log hyper-parameters
+        if self.logger is not None:
+            # save exp to get started (this is where the first experiment logs are written)
+            self.logger.log_hyperparams(ref_model.hparams_initial)
+            self.logger.log_graph(ref_model)
+            self.logger.save()
+
+        # wait for all to join if on distributed
+        self.accelerator_backend.barrier("setup_trainer")
+
+        # register auto-resubmit when on SLURM
+        self.slurm_connector.register_slurm_signal_handlers()
+
+        # track model now.
+        # if cluster resets state, the model will update with the saved weights
+        self.model = model
+
     def fit(
         self,
         model: LightningModule,
@@ -446,10 +486,6 @@ def fit(
         # hook
         self.data_connector.prepare_data(model)

-        # bookkeeping
-        # we reuse fit in .test() but change its behavior using this flag
-        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
-
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -554,13 +590,13 @@ def train(self):
         # hook
         self.train_loop.on_train_end()

-    def run_evaluation(self, test_mode: bool = False, max_batches=None):
+    def run_evaluation(self, max_batches=None):

         # used to know if we are logging for val, test + reset cached results
-        self.logger_connector.set_stage(test_mode, reset=True)
+        self.logger_connector.set_stage(self.testing, reset=True)

         # bookkeeping
-        self.evaluation_loop.testing = test_mode
+        self.evaluation_loop.testing = self.testing

         # prepare dataloaders
         dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
@@ -606,7 +642,7 @@ def run_evaluation(self, max_batches=None):

                     # lightning module methods
                     with self.profiler.profile("evaluation_step_and_end"):
-                        output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
+                        output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
                         output = self.evaluation_loop.evaluation_step_end(output)

                 # hook + store predictions
@@ -659,7 +695,7 @@ def run_test(self):
         # only load test dataloader for testing
         # self.reset_test_dataloader(ref_model)
         with self.profiler.profile("run_test_evaluation"):
-            eval_loop_results, _ = self.run_evaluation(test_mode=True)
+            eval_loop_results, _ = self.run_evaluation()

         if len(eval_loop_results) == 0:
             return 1
@@ -690,7 +726,7 @@ def run_sanity_check(self, ref_model):
             self.on_sanity_check_start()

             # run eval step
-            _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
+            _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)

             # allow no returns from eval
             if eval_results is not None and len(eval_results) > 0:
@@ -794,11 +830,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         # run tests
         self.tested_ckpt_path = ckpt_path
         self.testing = True
-        os.environ['PL_TESTING_MODE'] = '1'
         self.model = model
         results = self.fit(model)
         self.testing = False
-        del os.environ['PL_TESTING_MODE']

         # teardown
         if self.is_function_implemented('teardown'):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3c8a8d45d0411..47e254606af93 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -124,64 +124,26 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
         # check that model is configured correctly
         self.trainer.config_validator.verify_loop_configurations(model)

-    def setup_training(self, model: LightningModule):
-        """Sanity check a few things before starting actual training.
-
-        Args:
-            model: The model to run sanity test on.
+    def setup_training(self):
+        """
+        Sanity check a few things before starting actual training.
         """
-        # --------------------------
-        # Setup??
-        # --------------------------
-        ref_model = model
-        if self.trainer.data_parallel:
-            ref_model = model.module
-
-        # set the ranks and devices
-        self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank
-        self.trainer.accelerator_backend.dist.device = ref_model.device
-
-        # give model convenience properties
-        ref_model.trainer = self.trainer
-
-        # set local properties on the model
-        self.trainer.model_connector.copy_trainer_model_properties(ref_model)
-
-        # init amp. Must be done here instead of __init__ to allow ddp to work
-        if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu:
-            self.trainer.scaler = self.trainer.precision_connector.backend.scaler
-
-        # log hyper-parameters
-        if self.trainer.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
-            self.trainer.logger.log_graph(ref_model)
-            self.trainer.logger.save()
-
-        # wait for all to join if on distributed
-        self.trainer.accelerator_backend.barrier("setup_training")
-
-        # register auto-resubmit when on SLURM
-        self.trainer.slurm_connector.register_slurm_signal_handlers()
-
         # --------------------------
         # Pre-train
         # --------------------------
+        ref_model = self.trainer.get_model()
+
         # on pretrain routine start
         self.trainer.on_pretrain_routine_start(ref_model)
         if self.trainer.is_function_implemented("on_pretrain_routine_start"):
             ref_model.on_pretrain_routine_start()

         # print model summary
-        if self.trainer.is_global_zero and not self.trainer.testing:
+        if self.trainer.is_global_zero:
             ref_model.summarize(mode=self.trainer.weights_summary)

-        # track model now.
-        # if cluster resets state, the model will update with the saved weights
-        self.trainer.model = model
-
         # restore training state and model weights before hpc is called
-        self.trainer.checkpoint_connector.restore_weights(model)
+        self.trainer.checkpoint_connector.restore_weights()

         # on pretrain routine end
         self.trainer.on_pretrain_routine_end(ref_model)
@@ -597,7 +559,7 @@ def run_training_epoch(self):
             # -----------------------------------------
             should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
-                self.trainer.run_evaluation(test_mode=False)
+                self.trainer.run_evaluation()

                 # reset stage to train
                 self.trainer.logger_connector.set_stage("train")
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 9264e2a49810d..c9fac5cc04a45 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -16,7 +16,7 @@
 import time
 from collections import Counter
 from functools import wraps
-from typing import Callable, Any, Optional
+from typing import Any, Callable, Optional


 def enabled_only(fn: Callable):
@@ -133,7 +133,7 @@ def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr,
         self.saved_lr_scheduler_updates.append(loss_dict)

     @enabled_only
-    def track_eval_loss_history(self, test_mode, batch_idx, dataloader_idx, output):
+    def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
         loss_dict = {
             'sanity_check': self.trainer.running_sanity_check,
             'dataloader_idx': dataloader_idx,
@@ -142,7 +142,7 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
             'output': output
         }

-        if test_mode:
+        if self.trainer.testing:
             self.saved_test_losses.append(loss_dict)
         else:
             self.saved_val_losses.append(loss_dict)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 53debcebeb7cd..c9baf0db6976d 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -109,8 +109,6 @@ def test_trainer_callback_system(torch_save):
         call.on_init_end(trainer),
         call.setup(trainer, model, 'test'),
         call.on_fit_start(trainer, model),
-        call.on_pretrain_routine_start(trainer, model),
-        call.on_pretrain_routine_end(trainer, model),
         call.on_test_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index d286bbf3a9de6..64dc25101eae6 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -13,21 +13,21 @@
 #  limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from unittest.mock import MagicMock
 from typing import Optional
+from unittest.mock import MagicMock

 import pytest
 import torch
 from torch.utils.data import DataLoader, random_split

-from pytorch_lightning import LightningDataModule, Trainer, seed_everything
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.model_utils import is_overridden
 from tests.base import EvalModelTemplate
-from tests.base.datasets import TrialMNIST
 from tests.base.datamodules import TrialMNISTDataModule
+from tests.base.datasets import TrialMNIST
 from tests.base.develop_utils import reset_seed
-from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
-from pytorch_lightning.callbacks import ModelCheckpoint


 def test_can_prepare_data(tmpdir):
@@ -170,14 +170,14 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir):
 def test_dm_add_argparse_args(tmpdir):
     parser = ArgumentParser()
     parser = TrialMNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args(['--data_dir', './my_data'])
-    assert args.data_dir == './my_data'
+    args = parser.parse_args(['--data_dir', str(tmpdir)])
+    assert args.data_dir == str(tmpdir)


 def test_dm_init_from_argparse_args(tmpdir):
     parser = ArgumentParser()
     parser = TrialMNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args(['--data_dir', './my_data'])
+    args = parser.parse_args(['--data_dir', str(tmpdir)])
     dm = TrialMNISTDataModule.from_argparse_args(args)
     dm.prepare_data()
     dm.setup()
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 8a5d2f667bc32..5352e749c5e55 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -392,8 +392,6 @@ def on_test_model_train(self):

     expected = [
         'on_fit_start',
-        'on_pretrain_routine_start',
-        'on_pretrain_routine_end',
         'on_test_model_eval',
         'on_test_epoch_start',
         'on_test_batch_start',
diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py
index 3c43b201f52e4..75e1ec7724967 100644
--- a/tests/models/test_torchscript.py
+++ b/tests/models/test_torchscript.py
@@ -18,7 +18,7 @@

 from tests.base import BoringModel
 from tests.base.datamodules import TrialMNISTDataModule
-from tests.base.models import ParityModuleRNN, BasicGAN
+from tests.base.models import BasicGAN, ParityModuleRNN


 @pytest.mark.parametrize("modelclass", [
@@ -116,10 +116,10 @@ def test_torchscript_retain_training_state():
     ParityModuleRNN,
     BasicGAN,
 ])
-def test_torchscript_properties(modelclass):
+def test_torchscript_properties(tmpdir, modelclass):
     """ Test that scripted LightningModule has unnecessary methods removed. """
     model = modelclass()
-    model.datamodule = TrialMNISTDataModule()
+    model.datamodule = TrialMNISTDataModule(tmpdir)
     script = model.to_torchscript()
     assert not hasattr(script, "datamodule")
     assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
index 9e2023d27d928..3a9a87f84e5d9 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
@@ -15,8 +15,9 @@
 Tests to ensure that the training loop works with a dict
 """
 import os
-from pytorch_lightning.core.lightning import LightningModule
+
 from pytorch_lightning import Trainer
+from pytorch_lightning.core.lightning import LightningModule
 from tests.base.deterministic_model import DeterministicModel


@@ -43,7 +44,7 @@ def backward(self, loss, optimizer, optimizer_idx):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    out, eval_results = trainer.run_evaluation(test_mode=False)
+    out, eval_results = trainer.run_evaluation()
     assert len(out) == 1
     assert len(eval_results) == 0

@@ -74,7 +75,7 @@ def test_validation_step_scalar_return(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    out, eval_results = trainer.run_evaluation(test_mode=False)
+    out, eval_results = trainer.run_evaluation()
     assert len(out) == 1
     assert len(eval_results) == 2
     assert eval_results[0] == 171 and eval_results[1] == 171
@@ -106,7 +107,7 @@ def test_validation_step_arbitrary_dict_return(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()
     assert len(callback_metrics) == 1
     assert len(eval_results) == 2
     assert eval_results[0]['some'] == 171
@@ -144,7 +145,7 @@ def test_validation_step_dict_return(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()
     assert len(callback_metrics) == 1
     assert len(callback_metrics[0]) == 5
     assert len(eval_results) == 2
@@ -186,7 +187,7 @@ def test_val_step_step_end_no_return(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()
     assert len(callback_metrics) == 1
     assert len(eval_results) == 0

@@ -218,7 +219,7 @@ def test_val_step_step_end(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()

     assert len(callback_metrics) == 1
     assert len(callback_metrics[0]) == 6
@@ -264,7 +265,7 @@ def test_no_val_step_end(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()
     assert len(callback_metrics) == 1
     assert len(callback_metrics[0]) == 6
     assert len(eval_results) == 1
@@ -308,7 +309,7 @@ def test_full_val_loop(tmpdir):

     # out are the results of the full loop
     # eval_results are output of _evaluate
-    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    callback_metrics, eval_results = trainer.run_evaluation()
     assert len(callback_metrics) == 1
     assert len(callback_metrics[0]) == 7
     assert len(eval_results) == 1
diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py
index da08ffe710e75..53636bed66f56 100644
--- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py
@@ -292,7 +292,7 @@ def validation_epoch_end(self, outputs) -> None:
         max_epochs=1,
         log_every_n_steps=1,
         weights_summary=None,
-        callbacks=[ModelCheckpoint(dirpath='val_loss')],
+        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
     )
     trainer.fit(model)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 50463f5c4b5e2..2fc6cb60c7fb0 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -656,11 +656,11 @@ def configure_optimizers(self):
     assert model.called


+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_step_with_optimizer_closure(tmpdir):
     """
     Tests that `step` works with optimizer_closure
     """
-    os.environ['PL_DEV_DEBUG'] = '1'

     class TestModel(BoringModel):

@@ -736,11 +736,11 @@ def configure_optimizers(self):
     assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()


+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
     """
     Tests that `step` works with optimizer_closure and accumulated_grad
     """
-    os.environ['PL_DEV_DEBUG'] = '1'

     class TestModel(BoringModel):
         def __init__(self):
@@ -798,12 +798,12 @@ def configure_optimizers(self):
     assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2


+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @patch("torch.optim.SGD.step")
 def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir):
     """
     Tests that `step` works with optimizer_closure and extra arguments
     """
-    os.environ['PL_DEV_DEBUG'] = '1'

     class TestModel(BoringModel):
         def __init__(self):
@@ -854,13 +854,13 @@ def configure_optimizers(self):
     step_mock.assert_has_calls(expected_calls)


+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @patch("torch.optim.Adam.step")
 @patch("torch.optim.SGD.step")
 def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
     """
-    os.environ['PL_DEV_DEBUG'] = '1'

     class TestModel(BoringModel):
         def __init__(self):
@@ -933,6 +933,7 @@ def configure_optimizers(self):
     mock_adam_step.assert_has_calls(expected_calls)


+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @patch("torch.optim.Adam.step")
 @patch("torch.optim.SGD.step")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -941,7 +942,6 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_ste
     """
     Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
     """
-    os.environ['PL_DEV_DEBUG'] = '1'

     class TestModel(BoringModel):
         def __init__(self):
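
For orientation, the consolidated control flow after this patch looks roughly like the sketch below. It is illustrative only, not part of the patch: the stub classes stand in for the real Trainer, TrainLoop, and Accelerator, but the call order follows the diff above (the base accelerator's generic train() delegates shared setup to Trainer.setup_trainer(), and the pre-train routine now runs inside train_or_test() on the training path only).

class StubTrainLoop:
    def setup_training(self):
        # pre-train routine kept in TrainLoop.setup_training():
        # pretrain hooks, model summary, checkpoint_connector.restore_weights()
        print("pre-train routine")


class StubTrainer:
    def __init__(self, model, testing=False):
        self.model = model
        self.testing = testing
        self.train_loop = StubTrainLoop()

    def setup_trainer(self, model):
        # shared setup formerly duplicated in every accelerator: ranks/devices,
        # copy_trainer_model_properties, AMP scaler, hparams logging, barrier,
        # SLURM signal handlers, tracking the model on the trainer
        print("setup_trainer")
        self.model = model

    def run_test(self):
        return "test results"

    def train(self):
        return "train results"


class StubAccelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):  # mirrors the new base Accelerator.train() in the diff
        self.trainer.setup_trainer(self.trainer.model)
        return self.train_or_test()

    def train_or_test(self):  # mirrors the updated Accelerator.train_or_test()
        if self.trainer.testing:
            return self.trainer.run_test()
        self.trainer.train_loop.setup_training()
        return self.trainer.train()


print(StubAccelerator(StubTrainer("model")).train())                # training path
print(StubAccelerator(StubTrainer("model", testing=True)).train())  # test path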