From f934a98aac23f89ab23c8d71012b3cd74784d60a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 24 May 2021 18:36:38 +0200 Subject: [PATCH 1/7] Refactor some loops code and hook tests --- .../trainer/connectors/optimizer_connector.py | 21 +- pytorch_lightning/trainer/training_loop.py | 46 +- tests/models/test_hooks.py | 466 +++++++++++------- tests/trainer/loops/test_training_loop.py | 112 ----- 4 files changed, 313 insertions(+), 332 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index e7fbdf9b18c02..2797504288bd3 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -11,30 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import List, Optional +from weakref import proxy +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException class OptimizerConnector: - def __init__(self, trainer): - self.trainer = trainer + def __init__(self, trainer: 'pl.Trainer') -> None: + self.trainer = proxy(trainer) - def on_trainer_init(self): + def on_trainer_init(self) -> None: self.trainer.lr_schedulers = [] self.trainer.optimizers = [] self.trainer.optimizer_frequencies = [] - def update_learning_rates( - self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None - ): + def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None: """Update learning rates. Args: interval: either 'epoch' or 'step'. - monitor_metrics: dict of possible values to monitor + opt_indices: indices of the optimizers to update. 
""" if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization: return @@ -55,10 +55,7 @@ def update_learning_rates( monitor_key, monitor_val = None, None if lr_scheduler['reduce_on_plateau']: monitor_key = lr_scheduler['monitor'] - monitor_val = ( - monitor_metrics.get(monitor_key) if monitor_metrics is not None else - self.trainer.logger_connector.callback_metrics.get(monitor_key) - ) + monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key) if monitor_val is None: if lr_scheduler.get('strict', True): avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys()) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a555146875eb5..9fe4b5640c385 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import contextmanager, suppress -from copy import copy, deepcopy +from copy import copy from functools import partial, update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -478,7 +478,6 @@ def run_training_epoch(self): train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - batch_idx = None is_last_batch = None @@ -525,8 +524,7 @@ def run_training_epoch(self): self.save_loggers_on_train_batch_end() # update LR schedulers - monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.update_lr_schedulers('step') self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training @@ -567,7 +565,7 @@ def run_training_epoch(self): # update epoch level lr_schedulers if no val loop outside train loop is triggered if not should_check_val or should_train_only: - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + self.update_lr_schedulers('epoch') if should_train_only: self.check_checkpoint_callback(True) @@ -864,17 +862,16 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # track gradients result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) - def update_train_loop_lr_schedulers(self, monitor_metrics=None): - num_accumulated_batches_reached = self._accumulated_batches_reached() - num_training_batches_reached = self._num_training_batches_reached() - - if num_accumulated_batches_reached or num_training_batches_reached: - # update lr - self.trainer.optimizer_connector.update_learning_rates( - interval="step", - monitor_metrics=monitor_metrics, - opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()], - ) + def update_lr_schedulers(self, interval: str) -> None: + if interval == "step": + finished_accumulation = self._accumulated_batches_reached() + finished_epoch = self._num_training_batches_reached() + if not finished_accumulation and not finished_epoch: + return + self.trainer.optimizer_connector.update_learning_rates( + interval=interval, + opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()], + ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() @@ -898,14 +895,20 @@ def should_accumulate(self): def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: """ Decide if we should run validation. 
""" - if not self.trainer.enable_validation: return False - # check if this epoch is eligible to run validation - if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + if not is_val_check_epoch: return False + is_infinite_dataset = self.trainer.val_check_batch == float('inf') + if is_last_batch and is_infinite_dataset: + return True + + if self.trainer.should_stop: + return True + # val_check_batch is inf for iterable datasets with no length defined # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False @@ -916,12 +919,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo # Note: num_training_batches is also inf for iterable datasets with no length defined epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") if on_epoch: - return ( - is_val_check_batch and epoch_end_val_check - ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + return is_val_check_batch and epoch_end_val_check else: return is_val_check_batch and not epoch_end_val_check diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 78f8d2c0a94e9..79b52ef4d5d29 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -249,186 +249,227 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): assert trainer.global_step == (batch_idx_ + 1) * max_epochs -def test_trainer_model_hook_system(tmpdir): - """Test the LightningModule hook system.""" - - class HookedModel(BoringModel): - - def __init__(self): - super().__init__() - self.called = [] - - def on_after_backward(self): - self.called.append("on_after_backward") - super().on_after_backward() - - def on_before_zero_grad(self, *args, **kwargs): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(*args, **kwargs) - - def on_epoch_start(self): - self.called.append("on_epoch_start") - super().on_epoch_start() - - def on_epoch_end(self): - self.called.append("on_epoch_end") - super().on_epoch_end() - - def on_fit_start(self): - self.called.append("on_fit_start") - super().on_fit_start() - - def on_fit_end(self): - self.called.append("on_fit_end") - super().on_fit_end() - - def on_hpc_load(self, *args, **kwargs): - self.called.append("on_hpc_load") - super().on_hpc_load(*args, **kwargs) - - def on_hpc_save(self, *args, **kwargs): - self.called.append("on_hpc_save") - super().on_hpc_save(*args, **kwargs) - - def on_load_checkpoint(self, *args, **kwargs): - self.called.append("on_load_checkpoint") - super().on_load_checkpoint(*args, **kwargs) - - def on_save_checkpoint(self, *args, **kwargs): - self.called.append("on_save_checkpoint") - super().on_save_checkpoint(*args, **kwargs) - - def on_pretrain_routine_start(self): - self.called.append("on_pretrain_routine_start") - super().on_pretrain_routine_start() +class HookedModel(BoringModel): + + def __init__(self): + super().__init__() + self.called = [] + self.train_batch = [ + 'on_train_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'training_step', + 'on_before_zero_grad', + 'optimizer_zero_grad', + 'backward', + 'on_after_backward', + 'optimizer_step', + 'on_train_batch_end', + ] + self.val_batch = [ + 'on_validation_batch_start', + 
'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_validation_batch_end', + ] + + def training_step(self, *args, **kwargs): + self.called.append("training_step") + return super().training_step(*args, **kwargs) + + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) + + def optimizer_zero_grad(self, *args, **kwargs): + self.called.append("optimizer_zero_grad") + super().optimizer_zero_grad(*args, **kwargs) + + def training_epoch_end(self, *args, **kwargs): + self.called.append("training_epoch_end") + super().training_epoch_end(*args, **kwargs) + + def backward(self, *args, **kwargs): + self.called.append("backward") + super().backward(*args, **kwargs) + + def on_after_backward(self): + self.called.append("on_after_backward") + super().on_after_backward() + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + self.called.append("optimizer_step") # append after as closure calls other methods + + def validation_epoch_end(self, *args, **kwargs): + self.called.append("validation_epoch_end") + super().validation_epoch_end(*args, **kwargs) + + def on_epoch_start(self): + self.called.append("on_epoch_start") + super().on_epoch_start() + + def on_epoch_end(self): + self.called.append("on_epoch_end") + super().on_epoch_end() + + def on_fit_start(self): + self.called.append("on_fit_start") + super().on_fit_start() + + def on_fit_end(self): + self.called.append("on_fit_end") + super().on_fit_end() + + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) + + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) + + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) + + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) + + def on_pretrain_routine_start(self): + self.called.append("on_pretrain_routine_start") + super().on_pretrain_routine_start() + + def on_pretrain_routine_end(self): + self.called.append("on_pretrain_routine_end") + super().on_pretrain_routine_end() + + def on_train_start(self): + self.called.append("on_train_start") + super().on_train_start() + + def on_train_end(self): + self.called.append("on_train_end") + super().on_train_end() + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) - def on_pretrain_routine_end(self): - self.called.append("on_pretrain_routine_end") - super().on_pretrain_routine_end() + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_start(self): - self.called.append("on_train_start") - super().on_train_start() + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) - def on_train_end(self): - 
self.called.append("on_train_end") - super().on_train_end() + def on_train_epoch_start(self): + self.called.append("on_train_epoch_start") + super().on_train_epoch_start() - def on_before_batch_transfer(self, *args, **kwargs): - self.called.append("on_before_batch_transfer") - return super().on_before_batch_transfer(*args, **kwargs) + def on_train_epoch_end(self): + self.called.append("on_train_epoch_end") + super().on_train_epoch_end() - def transfer_batch_to_device(self, *args, **kwargs): - self.called.append("transfer_batch_to_device") - return super().transfer_batch_to_device(*args, **kwargs) + def on_validation_start(self): + self.called.append("on_validation_start") + super().on_validation_start() - def on_after_batch_transfer(self, *args, **kwargs): - self.called.append("on_after_batch_transfer") - return super().on_after_batch_transfer(*args, **kwargs) + def on_validation_end(self): + self.called.append("on_validation_end") + super().on_validation_end() - def on_train_batch_start(self, *args, **kwargs): - self.called.append("on_train_batch_start") - super().on_train_batch_start(*args, **kwargs) + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_train_batch_end(self, *args, **kwargs): - self.called.append("on_train_batch_end") - super().on_train_batch_end(*args, **kwargs) + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) - def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") - super().on_train_epoch_start() - - def on_train_epoch_end(self): - self.called.append("on_train_epoch_end") - super().on_train_epoch_end() - - def on_validation_start(self): - self.called.append("on_validation_start") - super().on_validation_start() - - def on_validation_end(self): - self.called.append("on_validation_end") - super().on_validation_end() + def on_validation_epoch_start(self): + self.called.append("on_validation_epoch_start") + super().on_validation_epoch_start() - def on_validation_batch_start(self, *args, **kwargs): - self.called.append("on_validation_batch_start") - super().on_validation_batch_start(*args, **kwargs) + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append("on_validation_epoch_end") + super().on_validation_epoch_end(*args, **kwargs) - def on_validation_batch_end(self, *args, **kwargs): - self.called.append("on_validation_batch_end") - super().on_validation_batch_end(*args, **kwargs) + def on_test_start(self): + self.called.append("on_test_start") + super().on_test_start() - def on_validation_epoch_start(self): - self.called.append("on_validation_epoch_start") - super().on_validation_epoch_start() + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_validation_epoch_end(self, *args, **kwargs): - self.called.append("on_validation_epoch_end") - super().on_validation_epoch_end(*args, **kwargs) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) - def on_test_start(self): - self.called.append("on_test_start") - super().on_test_start() + def on_test_epoch_start(self): + self.called.append("on_test_epoch_start") + super().on_test_epoch_start() - def on_test_batch_start(self, *args, **kwargs): - self.called.append("on_test_batch_start") - 
super().on_test_batch_start(*args, **kwargs) + def on_test_epoch_end(self, *args, **kwargs): + self.called.append("on_test_epoch_end") + super().on_test_epoch_end(*args, **kwargs) - def on_test_batch_end(self, *args, **kwargs): - self.called.append("on_test_batch_end") - super().on_test_batch_end(*args, **kwargs) + def on_validation_model_eval(self): + self.called.append("on_validation_model_eval") + super().on_validation_model_eval() - def on_test_epoch_start(self): - self.called.append("on_test_epoch_start") - super().on_test_epoch_start() + def on_validation_model_train(self): + self.called.append("on_validation_model_train") + super().on_validation_model_train() - def on_test_epoch_end(self, *args, **kwargs): - self.called.append("on_test_epoch_end") - super().on_test_epoch_end(*args, **kwargs) + def on_test_model_eval(self): + self.called.append("on_test_model_eval") + super().on_test_model_eval() - def on_validation_model_eval(self): - self.called.append("on_validation_model_eval") - super().on_validation_model_eval() + def on_test_model_train(self): + self.called.append("on_test_model_train") + super().on_test_model_train() - def on_validation_model_train(self): - self.called.append("on_validation_model_train") - super().on_validation_model_train() + def on_test_end(self): + self.called.append("on_test_end") + super().on_test_end() - def on_test_model_eval(self): - self.called.append("on_test_model_eval") - super().on_test_model_eval() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) - def on_test_model_train(self): - self.called.append("on_test_model_train") - super().on_test_model_train() + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage) - def on_test_end(self): - self.called.append("on_test_end") - super().on_test_end() - - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) - - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") - super().teardown(stage) +def test_trainer_model_hook_system_fit(tmpdir): + """Test the LightningModule hook system.""" model = HookedModel() - - # fit model + train_batches = 2 + val_batches = 2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_val_batches=1, - limit_train_batches=2, - limit_test_batches=1, + limit_train_batches=train_batches, + limit_val_batches=val_batches, progress_bar_refresh_rate=0, weights_summary=None, ) - assert model.called == [] - trainer.fit(model) expected = [ 'setup_fit', @@ -439,11 +480,8 @@ def teardown(self, stage=None): 'on_validation_start', 'on_epoch_start', 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_validation_batch_end', + *(model.val_batch * val_batches), + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_validation_end', @@ -451,31 +489,16 @@ def teardown(self, stage=None): 'on_train_start', 'on_epoch_start', 'on_train_epoch_start', - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_before_zero_grad', - 'on_after_backward', - 'on_train_batch_end', - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_before_zero_grad', - 'on_after_backward', - 'on_train_batch_end', + *(model.train_batch * train_batches), + 'training_epoch_end', 'on_train_epoch_end', 'on_epoch_end', 
'on_validation_model_eval', 'on_validation_start', 'on_epoch_start', 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_validation_batch_end', + *(model.val_batch * val_batches), + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_save_checkpoint', @@ -487,8 +510,51 @@ def teardown(self, stage=None): ] assert model.called == expected + +def test_trainer_model_hook_system_fit_no_val(tmpdir): model = HookedModel() + train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0, + limit_train_batches=train_batches, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] + trainer.fit(model) + expected = [ + 'setup_fit', + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_train_start', + 'on_epoch_start', + 'on_train_epoch_start', + *(model.train_batch * train_batches), + 'training_epoch_end', + 'on_train_epoch_end', + 'on_epoch_end', + 'on_save_checkpoint', # from train epoch end + 'on_train_end', + 'on_fit_end', + 'teardown_fit', + ] + assert model.called == expected + + +def test_trainer_model_hook_system_validate(tmpdir): + model = HookedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] trainer.validate(model, verbose=False) expected = [ 'setup_validate', @@ -501,6 +567,7 @@ def teardown(self, stage=None): 'transfer_batch_to_device', 'on_after_batch_transfer', 'on_validation_batch_end', + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_validation_end', @@ -509,9 +576,18 @@ def teardown(self, stage=None): ] assert model.called == expected + +def test_trainer_model_hook_system_test(tmpdir): model = HookedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] trainer.test(model, verbose=False) - expected = [ 'setup_test', 'on_test_model_eval', @@ -647,30 +723,50 @@ def on_after_batch_transfer(self, *args, **kwargs): reload_dataloaders_every_epoch=True, ) trainer.fit(model, datamodule=dm) - expected = [ - 'prepare_data', 'setup_fit', 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'train_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'teardown_fit' + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit', ] assert dm.called == expected dm = HookedDataModule() trainer.validate(model, datamodule=dm, verbose=False) - expected = [ - 'prepare_data', 'setup_validate', 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'teardown_validate' + 
'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate', ] assert dm.called == expected dm = HookedDataModule() trainer.test(model, datamodule=dm, verbose=False) - expected = [ - 'prepare_data', 'setup_test', 'test_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'teardown_test' + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test', ] assert dm.called == expected diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index 2d32d8c8878e4..89a2ae0ec0d61 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -17,118 +17,6 @@ from tests.helpers import BoringModel -def test_training_loop_hook_call_order(tmpdir): - """Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs: - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks""" - - class HookedModel(BoringModel): - - def __init__(self): - super().__init__() - self.called = [] - - def on_epoch_start(self): - self.called.append("on_epoch_start") - super().on_epoch_start() - - def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") - super().on_train_epoch_start() - - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append("on_train_batch_start") - super().on_train_batch_start(batch, batch_idx, dataloader_idx) - - def training_step(self, batch, batch_idx): - self.called.append("training_step") - return super().training_step(batch, batch_idx) - - def on_before_zero_grad(self, optimizer): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(optimizer) - - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): - self.called.append("optimizer_zero_grad") - super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx) - - def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): - self.called.append("backward") - super().backward(loss, optimizer, optimizer_idx, *args, **kwargs) - - def on_after_backward(self): - self.called.append("on_after_backward") - super().on_after_backward() - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - on_tpu, - using_native_amp, - using_lbfgs, - ): - super().optimizer_step( - epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs - ) - self.called.append("optimizer_step") # append after as closure calls other methods - - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append("on_train_batch_end") - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) - - def training_epoch_end(self, outputs): - self.called.append("training_epoch_end") - super().training_epoch_end(outputs) - - def on_train_epoch_end(self, outputs): - self.called.append("on_train_epoch_end") - super().on_train_epoch_end(outputs) - - def on_epoch_end(self): - self.called.append("on_epoch_end") - super().on_epoch_end() - - model = HookedModel() - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=1, - limit_train_batches=1, - limit_test_batches=1, - progress_bar_refresh_rate=0, - weights_summary=None, - ) - - 
assert model.called == [] - - trainer.fit(model) - expected = [ - "on_epoch_start", # validation - "on_epoch_end", - "on_epoch_start", # training - "on_train_epoch_start", - "on_train_batch_start", - "training_step", - "on_before_zero_grad", - "optimizer_zero_grad", - "backward", - "on_after_backward", - "optimizer_step", - "on_train_batch_end", - "training_epoch_end", - "on_train_epoch_end", - "on_epoch_end", - "on_epoch_start", # validation - "on_epoch_end", - ] - assert model.called == expected - - def test_outputs_format(tmpdir): """Tests that outputs objects passed to model hooks and methods are consistent and in the correct format.""" From dea039befec535c5ce7dc146fdc2f852dd02423c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 24 May 2021 18:44:42 +0200 Subject: [PATCH 2/7] Bad merge --- tests/models/test_hooks.py | 169 +++++++++++++++++++++++++------------ 1 file changed, 117 insertions(+), 52 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f16d3b4b115f6..7e21e4dd382f3 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -229,58 +229,123 @@ def train_dataloader(self): trainer.fit(model) -def test_trainer_model_hook_system(tmpdir): - """Test the LightningModule hook system.""" - - class HookedModel(BoringModel): - - def __init__(self): - super().__init__() - self.called = [] - - def on_after_backward(self): - self.called.append("on_after_backward") - super().on_after_backward() - - def on_before_zero_grad(self, *args, **kwargs): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(*args, **kwargs) - - def on_epoch_start(self): - self.called.append("on_epoch_start") - super().on_epoch_start() - - def on_epoch_end(self): - self.called.append("on_epoch_end") - super().on_epoch_end() - - def on_fit_start(self): - self.called.append("on_fit_start") - super().on_fit_start() - - def on_fit_end(self): - self.called.append("on_fit_end") - super().on_fit_end() - - def on_hpc_load(self, *args, **kwargs): - self.called.append("on_hpc_load") - super().on_hpc_load(*args, **kwargs) - - def on_hpc_save(self, *args, **kwargs): - self.called.append("on_hpc_save") - super().on_hpc_save(*args, **kwargs) - - def on_load_checkpoint(self, *args, **kwargs): - self.called.append("on_load_checkpoint") - super().on_load_checkpoint(*args, **kwargs) - - def on_save_checkpoint(self, *args, **kwargs): - self.called.append("on_save_checkpoint") - super().on_save_checkpoint(*args, **kwargs) - - def on_pretrain_routine_start(self): - self.called.append("on_pretrain_routine_start") - super().on_pretrain_routine_start() +class HookedModel(BoringModel): + + def __init__(self): + super().__init__() + self.called = [] + self.train_batch = [ + 'on_train_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'training_step', + 'on_before_zero_grad', + 'optimizer_zero_grad', + 'backward', + 'on_after_backward', + 'optimizer_step', + 'on_train_batch_end', + ] + self.val_batch = [ + 'on_validation_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_validation_batch_end', + ] + + def training_step(self, *args, **kwargs): + self.called.append("training_step") + return super().training_step(*args, **kwargs) + + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) + + def optimizer_zero_grad(self, *args, **kwargs): + 
self.called.append("optimizer_zero_grad") + super().optimizer_zero_grad(*args, **kwargs) + + def training_epoch_end(self, *args, **kwargs): + self.called.append("training_epoch_end") + super().training_epoch_end(*args, **kwargs) + + def backward(self, *args, **kwargs): + self.called.append("backward") + super().backward(*args, **kwargs) + + def on_after_backward(self): + self.called.append("on_after_backward") + super().on_after_backward() + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + self.called.append("optimizer_step") # append after as closure calls other methods + + def validation_epoch_end(self, *args, **kwargs): + self.called.append("validation_epoch_end") + super().validation_epoch_end(*args, **kwargs) + + def on_epoch_start(self): + self.called.append("on_epoch_start") + super().on_epoch_start() + + def on_epoch_end(self): + self.called.append("on_epoch_end") + super().on_epoch_end() + + def on_fit_start(self): + self.called.append("on_fit_start") + super().on_fit_start() + + def on_fit_end(self): + self.called.append("on_fit_end") + super().on_fit_end() + + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) + + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) + + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) + + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) + + def on_pretrain_routine_start(self): + self.called.append("on_pretrain_routine_start") + super().on_pretrain_routine_start() + + def on_pretrain_routine_end(self): + self.called.append("on_pretrain_routine_end") + super().on_pretrain_routine_end() + + def on_train_start(self): + self.called.append("on_train_start") + super().on_train_start() + + def on_train_end(self): + self.called.append("on_train_end") + super().on_train_end() + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) def on_train_batch_start(self, *args, **kwargs): self.called.append("on_train_batch_start") From e9db0d91eafa96a55e82214b4d1ce41a0e70f757 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 May 2021 00:21:00 +0200 Subject: [PATCH 3/7] Fix test --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1ffd44fc09f10..58ae77070968f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -902,7 +902,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo return False is_infinite_dataset = self.trainer.val_check_batch == float('inf') - if is_last_batch and is_infinite_dataset: + if on_epoch and is_last_batch and is_infinite_dataset: return True if self.trainer.should_stop: From 8e6e3a2abfd8df2d8be5822a81522147a47b4ca8 Mon Sep 17 00:00:00 
2001 From: Carlos Mocholi Date: Tue, 25 May 2021 00:30:38 +0200 Subject: [PATCH 4/7] Bad merge --- tests/models/test_hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 110a882810e4b..d87631676889f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -437,7 +437,6 @@ def teardown(self, stage=None): def test_trainer_model_hook_system_fit(tmpdir): - """Test the LightningModule hook system.""" model = HookedModel() train_batches = 2 val_batches = 2 From 52b93d379c2a883811526deebf204195497ff115 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 May 2021 00:31:55 +0200 Subject: [PATCH 5/7] Whitespace --- tests/models/test_hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d87631676889f..678f34d2984cf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -502,7 +502,6 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): weights_summary=None, ) assert model.called == [] - trainer.fit(model) expected = [ 'setup_fit', From 50a0ab2a5d7724b01349cd441ad234c0d7e43994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 25 May 2021 12:23:26 +0200 Subject: [PATCH 6/7] Update pytorch_lightning/trainer/training_loop.py --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 58ae77070968f..49af6a5475657 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -905,7 +905,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo if on_epoch and is_last_batch and is_infinite_dataset: return True - if self.trainer.should_stop: + if on_epoch and self.trainer.should_stop: return True # val_check_batch is inf for iterable datasets with no length defined From dd40f0440c227cb1ef49d6680217f8caba5ea8ab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 May 2021 12:55:23 +0200 Subject: [PATCH 7/7] Move comment and minor test changes --- pytorch_lightning/trainer/training_loop.py | 2 +- tests/callbacks/test_early_stopping.py | 2 +- tests/loggers/test_tensorboard.py | 49 ++++++---------------- 3 files changed, 14 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 49af6a5475657..3426ebc7e2dd9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -901,6 +901,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo if not is_val_check_epoch: return False + # val_check_batch is inf for iterable datasets with no length defined is_infinite_dataset = self.trainer.val_check_batch == float('inf') if on_epoch and is_last_batch and is_infinite_dataset: return True @@ -908,7 +909,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo if on_epoch and self.trainer.should_stop: return True - # val_check_batch is inf for iterable datasets with no length defined # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 
8f89bedeb4f38..b1242de725c7f 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -157,7 +157,7 @@ def test_early_stopping_patience_train( """Test to ensure that early stopping is not triggered before patience is exhausted.""" class ModelOverrideTrainReturn(BoringModel): - train_return_values = torch.Tensor(loss_values) + train_return_values = torch.tensor(loss_values) def training_epoch_end(self, outputs): loss = self.train_return_values[self.current_epoch] diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index f22cdcfe2bba4..f7fe1c3bfd47e 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): @mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics') -@pytest.mark.parametrize('expected', [ - ([5, 11, 17]), -]) -def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir): - """ - Tests to ensure that tensorboard log properly when accumulated_gradients > 1 - """ +def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir): + """Tests to ensure that tensorboard log properly when accumulated_gradients > 1""" class TestModel(BoringModel): def __init__(self): super().__init__() - self._count = 0 - self._indexes = [] - - def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('count', self._count, on_step=True, on_epoch=True) - self.log('loss', loss, on_step=True, on_epoch=True) + self.indexes = [] + def training_step(self, *args): + self.log('foo', 1, on_step=True, on_epoch=True) if not self.trainer.train_loop.should_accumulate(): if self.trainer.logger_connector.should_update_logs: - self._indexes.append(self.trainer.global_step) - - return loss - - def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('val_loss', loss, on_step=True, on_epoch=True) - return loss - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] + self.indexes.append(self.trainer.global_step) + return super().training_step(*args) model = TestModel() model.training_epoch_end = None - model.validation_epoch_end = None - logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False) - trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=12, limit_val_batches=0, max_epochs=3, - gpus=0, accumulate_grad_batches=2, logger=[logger_0], log_every_n_steps=3, ) trainer.fit(model) - mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]] - assert mock_count_epochs == expected + calls = [m[2] for m in mock_log_metrics.mock_calls] + count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]] + assert count_epochs == [5, 11, 17] - mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]] - assert model._indexes == mock_count_steps + count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]] + assert count_steps == model.indexes @mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter')
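
Editor's note (not part of the patch): the bulk of this series moves the hook-order assertions into a shared HookedModel whose per-batch hook names live in reusable `train_batch` / `val_batch` lists, so each fit/no-val/validate/test variant composes its expected sequence instead of repeating long literals. Below is a minimal, dependency-free sketch of that tracking-and-composition pattern, with no Lightning involved; the names MiniModel and run_fake_fit are illustrative placeholders, not Lightning API, and the hook list is deliberately shortened.

    # Sketch of the hook-tracking pattern from the refactored tests:
    # every hook appends its name to `called`, and the expected sequence
    # is composed from a reusable per-batch list, as in `model.train_batch * n`.
    class MiniModel:

        def __init__(self):
            self.called = []
            # hooks expected for every training batch, in order (shortened for illustration)
            self.train_batch = [
                'on_train_batch_start',
                'training_step',
                'backward',
                'on_train_batch_end',
            ]

        def hook(self, name):
            self.called.append(name)


    def run_fake_fit(model, train_batches):
        # stand-in for trainer.fit(): fires the per-batch hooks in order
        model.hook('on_train_start')
        for _ in range(train_batches):
            for name in model.train_batch:
                model.hook(name)
        model.hook('on_train_end')


    model = MiniModel()
    train_batches = 2
    run_fake_fit(model, train_batches)
    expected = [
        'on_train_start',
        *(model.train_batch * train_batches),  # same composition trick as the real test
        'on_train_end',
    ]
    assert model.called == expected

Keeping the per-batch hook order in one list means an ordering change (for example, where optimizer_step lands relative to backward) only needs updating in one place, and every test variant picks it up automatically.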