From 9c40fee7983158a9aa3f608259d87a18597ef53d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 21 Dec 2020 01:44:53 +0530 Subject: [PATCH 01/24] Separate epoch validation from step validation --- pytorch_lightning/callbacks/early_stopping.py | 5 ++++ pytorch_lightning/trainer/training_loop.py | 30 +++++++++++++++---- tests/callbacks/test_callbacks.py | 4 +-- .../test_checkpoint_callback_frequency.py | 3 +- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 3e15d8462350c..2a56afd3fc0c7 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -89,6 +89,7 @@ def __init__( self.stopped_epoch = 0 self.mode = mode self.warned_result_obj = False + self._last_global_step_called = -1 # Indicates, if eval results are used as basis for early stopping # It is set to False initially and overwritten, if eval results have been validated self.based_on_eval_results = False @@ -163,6 +164,10 @@ def on_validation_end(self, trainer, pl_module): if trainer.running_sanity_check: return + if self._last_global_step_called == trainer.global_step: + return + + self._last_global_step_called = trainer.global_step self._run_early_stopping_check(trainer, pl_module) def on_validation_epoch_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 47e254606af93..c33dceeccfec8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -522,7 +522,6 @@ def run_training_epoch(self): # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - should_check_val = False for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx @@ -606,8 +605,14 @@ def run_training_epoch(self): self.num_optimizers ) - # when no val loop is present or fast-dev-run still need to call checkpoints - self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + if should_check_val: + self.trainer.run_evaluation(test_mode=False) + # reset stage to train + self.trainer.logger_connector.set_stage("train") + + should_train_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0) + self.check_checkpoint_callback(should_train_check) # increment the global step once # progress global step according to grads progress @@ -852,14 +857,27 @@ def should_accumulate(self): is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) - def should_check_val_fx(self, batch_idx, is_last_batch): + def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and is_val_check_epoch - should_check_val = is_val_check_batch or self.trainer.should_stop is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset) + epoch_end_val_check = self.trainer.val_check_batch 
== self.trainer.num_training_batches + + if on_epoch: + should_check_val = ( + can_check_val + and ((is_val_check_batch and epoch_end_val_check) + or self.trainer.should_stop + or is_last_batch_for_infinite_dataset) + ) + else: + should_check_val = ( + can_check_val + and is_val_check_batch + and not epoch_end_val_check + ) return should_check_val diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index c9baf0db6976d..d6421877e80ad 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -86,6 +86,8 @@ def test_trainer_callback_system(torch_save): call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), call.on_batch_end(trainer, model), call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0), + call.on_epoch_end(trainer, model), + call.on_train_epoch_end(trainer, model, ANY), call.on_validation_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), @@ -93,8 +95,6 @@ def test_trainer_callback_system(torch_save): call.on_validation_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), - call.on_epoch_end(trainer, model), - call.on_train_epoch_end(trainer, model, ANY), call.on_train_end(trainer, model), call.on_fit_end(trainer, model), call.teardown(trainer, model, 'fit'), diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index f9686dce159dd..69a8407a88590 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -56,7 +56,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, default_root_dir=tmpdir, max_epochs=epochs, weights_summary=None, - val_check_interval=val_check_interval + val_check_interval=val_check_interval, + progress_bar_refresh_rate=0, ) trainer.fit(model) From ed6ebf1407a102ea55f3ba9f04d63c5cfeb94558 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 21 Dec 2020 01:57:34 +0530 Subject: [PATCH 02/24] update system --- pytorch_lightning/trainer/training_loop.py | 13 +++++++------ tests/models/test_hooks.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c33dceeccfec8..e7aeb4eaf3211 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -573,11 +573,12 @@ def run_training_epoch(self): self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training - if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1: - accumulation_done = self._accumulated_batches_reached() - # Ensure accumulation across batches has completed before breaking loop - if accumulation_done: - break + if ( + self.trainer.max_steps is not None + and self.trainer.max_steps == self.trainer.global_step + 1 + and self._accumulated_batches_reached() + ): + break # end epoch early # stop when the flag is changed or we've gone past the amount @@ -588,7 +589,7 @@ def run_training_epoch(self): self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches - if (batch_idx + 1) >= self.trainer.num_training_batches: + if self._num_training_batches_reached(): break # progress global step according to grads progress diff --git a/tests/models/test_hooks.py 
b/tests/models/test_hooks.py index 5352e749c5e55..5653dc3903170 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -372,6 +372,8 @@ def on_test_model_train(self): 'on_after_backward', 'on_before_zero_grad', 'on_train_batch_end', + 'on_epoch_end', + 'on_train_epoch_end', 'on_validation_model_eval', 'on_validation_epoch_start', 'on_validation_batch_start', @@ -379,8 +381,6 @@ def on_test_model_train(self): 'on_validation_epoch_end', 'on_save_checkpoint', 'on_validation_model_train', - 'on_epoch_end', - 'on_train_epoch_end', 'on_train_end', 'on_fit_end', ] From 236b0526231f7bce8464b96b74f72fdc9d65f19b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 26 Dec 2020 15:43:26 +0530 Subject: [PATCH 03/24] test --- pytorch_lightning/trainer/trainer.py | 4 +--- pytorch_lightning/trainer/training_loop.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c3ef0e507789e..d9e3618b2af5b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,6 +18,7 @@ import warnings from pathlib import Path from typing import Dict, Iterable, List, Optional, Union +import warnings import torch from torch.utils.data import DataLoader @@ -562,9 +563,6 @@ def train(self): if self.max_steps and self.max_steps <= self.global_step: return - # update LR schedulers - self.optimizer_connector.update_learning_rates(interval='epoch') - # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e7aeb4eaf3211..8887fcc7de178 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -509,7 +509,6 @@ def tbptt_split_batch(self, batch): return splits def run_training_epoch(self): - # get model model = self.trainer.get_model() @@ -606,6 +605,9 @@ def run_training_epoch(self): self.num_optimizers ) + # update LR schedulers + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: self.trainer.run_evaluation(test_mode=False) From 788203b3e39f5bc1ccf91e4c4655ab3dd7040fe9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 27 Dec 2020 18:27:04 +0530 Subject: [PATCH 04/24] baked logic in callbacks --- pytorch_lightning/callbacks/early_stopping.py | 23 ------------ .../callbacks/model_checkpoint.py | 26 ++++++++++++-- pytorch_lightning/trainer/training_loop.py | 36 ++++--------------- 3 files changed, 30 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 2a56afd3fc0c7..f9e41b6a265cd 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -89,10 +89,6 @@ def __init__( self.stopped_epoch = 0 self.mode = mode self.warned_result_obj = False - self._last_global_step_called = -1 - # Indicates, if eval results are used as basis for early stopping - # It is set to False initially and overwritten, if eval results have been validated - self.based_on_eval_results = False self.__init_monitor_mode() @@ -164,25 +160,6 @@ def on_validation_end(self, trainer, pl_module): if trainer.running_sanity_check: return - if self._last_global_step_called == trainer.global_step: - return - - 
self._last_global_step_called = trainer.global_step - self._run_early_stopping_check(trainer, pl_module) - - def on_validation_epoch_end(self, trainer, pl_module): - if trainer.fast_dev_run or trainer.running_sanity_check: - return - - if self._validate_condition_metric(trainer.callback_metrics): - # turn off early stopping in on_train_epoch_end - self.based_on_eval_results = True - - def on_train_epoch_end(self, trainer, pl_module, outputs): - # disable early stopping in train loop when there's a val loop - if self.based_on_eval_results: - return - self._run_early_stopping_check(trainer, pl_module) def _run_early_stopping_check(self, trainer, pl_module): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e5c960b3c002b..c2c4301135505 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,11 +20,11 @@ """ +from copy import deepcopy import numbers import os -import re -from copy import deepcopy from pathlib import Path +import re from typing import Any, Dict, Optional, Union import numpy as np @@ -202,6 +202,28 @@ def on_validation_end(self, trainer, pl_module): """ self.save_checkpoint(trainer, pl_module) + def on_train_epoch_end(self, trainer, pl_module, outputs): + """ + checkpoints can be saved at the end of the train loop if no validation is done + """ + should_train_check = not trainer.enable_validation and (sum(trainer.num_val_batches) == 0) + if should_train_check and trainer.checkpoint_connector.has_trained: + self.save_checkpoint(trainer, pl_module) + + def on_train_end(self, trainer, pl_module): + """ + checkpoints to be saved when training ends + """ + # when a checkpoint was saved at the last step + if trainer.checkpoint_connector.has_trained: + if self.save_last: + rank_zero_info('Saving latest checkpoint...') + + # need to temporarily decrease the global step to avoid saving duplicates + trainer.global_step -= 1 + self.save_checkpoint(trainer, pl_module) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8887fcc7de178..c0f5c74a023ef 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -18,7 +18,7 @@ import torch import torch.distributed as torch_distrib -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer @@ -156,12 +156,6 @@ def on_train_end(self): self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_save=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -182,19 +176,6 @@ def on_train_end(self): model.cpu() torch.cuda.empty_cache() - def check_checkpoint_callback(self, should_save, is_last=False): - # TODO bake this logic into the checkpoint callback - if should_save and self.trainer.checkpoint_connector.has_trained: - checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] - - if is_last and any(c.save_last for c in checkpoint_callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.get_model() - - for callback in checkpoint_callbacks: - callback.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -594,10 +575,8 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - # epoch end hook - self.run_on_epoch_end_hook(epoch_output) - - # log epoch metrics + # inform logger the batch loop has finished and log epoch metrics + self.trainer.logger_connector.on_train_epoch_end() self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, @@ -608,15 +587,15 @@ def run_training_epoch(self): # update LR schedulers self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + # epoch end hook + self.run_on_epoch_end_hook(epoch_output) + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: self.trainer.run_evaluation(test_mode=False) # reset stage to train self.trainer.logger_connector.set_stage("train") - should_train_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0) - self.check_checkpoint_callback(should_train_check) - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() @@ -834,9 +813,6 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): - # inform logger the batch loop has finished - self.trainer.logger_connector.on_train_epoch_end() - self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end', epoch_output) From c7b24ca6b131990a55096409478f9d7d02d3b052 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 27 Dec 2020 20:55:07 +0530 Subject: [PATCH 05/24] unbake logic in callbacks --- .../callbacks/model_checkpoint.py | 22 -------- pytorch_lightning/trainer/training_loop.py | 54 +++++++++++++++---- tests/callbacks/test_callbacks.py | 4 +- tests/models/test_hooks.py | 4 +- 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index c2c4301135505..5f812b5d924b0 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -202,28 +202,6 @@ def on_validation_end(self, trainer, pl_module): """ self.save_checkpoint(trainer, pl_module) - def on_train_epoch_end(self, trainer, pl_module, outputs): - """ - checkpoints can be saved at the end of the train loop if no validation is done - """ - 
should_train_check = not trainer.enable_validation and (sum(trainer.num_val_batches) == 0) - if should_train_check and trainer.checkpoint_connector.has_trained: - self.save_checkpoint(trainer, pl_module) - - def on_train_end(self, trainer, pl_module): - """ - checkpoints to be saved when training ends - """ - # when a checkpoint was saved at the last step - if trainer.checkpoint_connector.has_trained: - if self.save_last: - rank_zero_info('Saving latest checkpoint...') - - # need to temporarily decrease the global step to avoid saving duplicates - trainer.global_step -= 1 - self.save_checkpoint(trainer, pl_module) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c0f5c74a023ef..ae340451bca98 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -156,6 +156,12 @@ def on_train_end(self): self._teardown_already_run = True + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -176,6 +182,28 @@ def on_train_end(self): model.cpu() torch.cuda.empty_cache() + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + + if is_last and any(cb.save_last for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.get_model() + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.get_model() + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -575,8 +603,16 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - # inform logger the batch loop has finished and log epoch metrics - self.trainer.logger_connector.on_train_epoch_end() + # epoch end hook + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + if should_check_val: + self.trainer.run_evaluation(test_mode=False) + # reset stage to train + self.trainer.logger_connector.set_stage("train") + + self.run_on_epoch_end_hook(epoch_output) + + # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, @@ -587,14 +623,9 @@ def run_training_epoch(self): # update LR schedulers self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - # epoch end hook - self.run_on_epoch_end_hook(epoch_output) - - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) - if should_check_val: - self.trainer.run_evaluation(test_mode=False) - # reset stage to train - 
self.trainer.logger_connector.set_stage("train") + should_train_only_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0) + self.check_checkpoint_callback(should_train_only_check) + self.check_early_stopping_callback(should_train_only_check) # increment the global step once # progress global step according to grads progress @@ -813,6 +844,9 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): + # inform logger the batch loop has finished + self.trainer.logger_connector.on_train_epoch_end() + self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end', epoch_output) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d6421877e80ad..c9baf0db6976d 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -86,8 +86,6 @@ def test_trainer_callback_system(torch_save): call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), call.on_batch_end(trainer, model), call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0), - call.on_epoch_end(trainer, model), - call.on_train_epoch_end(trainer, model, ANY), call.on_validation_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), @@ -95,6 +93,8 @@ def test_trainer_callback_system(torch_save): call.on_validation_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), + call.on_epoch_end(trainer, model), + call.on_train_epoch_end(trainer, model, ANY), call.on_train_end(trainer, model), call.on_fit_end(trainer, model), call.teardown(trainer, model, 'fit'), diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5653dc3903170..5352e749c5e55 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -372,8 +372,6 @@ def on_test_model_train(self): 'on_after_backward', 'on_before_zero_grad', 'on_train_batch_end', - 'on_epoch_end', - 'on_train_epoch_end', 'on_validation_model_eval', 'on_validation_epoch_start', 'on_validation_batch_start', @@ -381,6 +379,8 @@ def on_test_model_train(self): 'on_validation_epoch_end', 'on_save_checkpoint', 'on_validation_model_train', + 'on_epoch_end', + 'on_train_epoch_end', 'on_train_end', 'on_fit_end', ] From 42b0c7bd861fe64f756d05f89fcb51154cbb5fec Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 31 Dec 2020 03:28:13 +0530 Subject: [PATCH 06/24] fix the call for scheduler --- pytorch_lightning/trainer/evaluation_loop.py | 12 --------- pytorch_lightning/trainer/trainer.py | 8 ++++-- pytorch_lightning/trainer/training_loop.py | 27 +++++++++++--------- tests/callbacks/test_callbacks.py | 4 +-- tests/models/test_hooks.py | 4 +-- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 63f65bead2579..c72b188ac4a31 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -71,18 +71,6 @@ def get_evaluation_dataloaders(self, max_batches): return dataloaders, max_batches - def should_skip_evaluation(self, dataloaders, max_batches): - # skip when dataloaders aren't defined - if dataloaders is None: - return True - - # enable disabling validation step with limit_val_batches = 0 - should_skip = sum(max_batches) == 0 - if 
should_skip: - return True - - return False - def on_evaluation_start(self, *args, **kwargs): if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d9e3618b2af5b..e228749abef36 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -588,7 +588,7 @@ def train(self): # hook self.train_loop.on_train_end() - def run_evaluation(self, max_batches=None): + def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results self.logger_connector.set_stage(self.testing, reset=True) @@ -600,7 +600,7 @@ def run_evaluation(self, max_batches=None): dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) # check if we want to skip this evaluation - if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): + if sum(max_batches) == 0: return [], [] # ref model @@ -661,6 +661,10 @@ def run_evaluation(self, max_batches=None): # hook self.evaluation_loop.on_evaluation_epoch_end() + # update epoch-level lr_schedulers + if on_epoch: + self.optimizer_connector.update_learning_rates(interval='epoch') + # hook self.evaluation_loop.on_evaluation_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ae340451bca98..0cbab3c860f18 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -597,19 +597,13 @@ def run_training_epoch(self): self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(): + if self._num_training_batches_reached(is_last_batch): break # progress global step according to grads progress self.increment_accumulated_grad_global_step() # epoch end hook - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) - if should_check_val: - self.trainer.run_evaluation(test_mode=False) - # reset stage to train - self.trainer.logger_connector.set_stage("train") - self.run_on_epoch_end_hook(epoch_output) # log epoch metrics @@ -620,10 +614,19 @@ def run_training_epoch(self): self.num_optimizers ) - # update LR schedulers - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + if should_check_val: + self.trainer.run_evaluation(test_mode=False, on_epoch=True) + # reset stage to train + self.trainer.logger_connector.set_stage("train") + + should_skip_eval = sum(self.trainer.num_val_batches) == 0 + should_train_only_check = not self.trainer.enable_validation and should_skip_eval + + if should_skip_eval or should_train_only_check: + # update epoch level lr_schedulers + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - should_train_only_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0) self.check_checkpoint_callback(should_train_only_check) self.check_early_stopping_callback(should_train_only_check) @@ -861,8 +864,8 @@ def increment_accumulated_grad_global_step(self): def _accumulated_batches_reached(self): return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 - def _num_training_batches_reached(self): - return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches + def _num_training_batches_reached(self, is_last_batch=False): + return 
(self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index c9baf0db6976d..d6421877e80ad 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -86,6 +86,8 @@ def test_trainer_callback_system(torch_save): call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), call.on_batch_end(trainer, model), call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0), + call.on_epoch_end(trainer, model), + call.on_train_epoch_end(trainer, model, ANY), call.on_validation_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), @@ -93,8 +95,6 @@ def test_trainer_callback_system(torch_save): call.on_validation_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), - call.on_epoch_end(trainer, model), - call.on_train_epoch_end(trainer, model, ANY), call.on_train_end(trainer, model), call.on_fit_end(trainer, model), call.teardown(trainer, model, 'fit'), diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5352e749c5e55..5653dc3903170 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -372,6 +372,8 @@ def on_test_model_train(self): 'on_after_backward', 'on_before_zero_grad', 'on_train_batch_end', + 'on_epoch_end', + 'on_train_epoch_end', 'on_validation_model_eval', 'on_validation_epoch_start', 'on_validation_batch_start', @@ -379,8 +381,6 @@ def on_test_model_train(self): 'on_validation_epoch_end', 'on_save_checkpoint', 'on_validation_model_train', - 'on_epoch_end', - 'on_train_epoch_end', 'on_train_end', 'on_fit_end', ] From 0cc0254aa022a9454ea2afcbf49212507f71f86e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 2 Jan 2021 18:28:43 +0530 Subject: [PATCH 07/24] use property --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0cbab3c860f18..38ec8c4251c30 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -185,7 +185,7 @@ def on_train_end(self): def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + callbacks = self.trainer.checkpoint_callbacks if is_last and any(cb.save_last for cb in callbacks): rank_zero_info("Saving latest checkpoint...") From 15e09b0fc631ef6656f80874f683de14897b1523 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 17 Jan 2021 16:18:44 +0530 Subject: [PATCH 08/24] pep --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5f812b5d924b0..e5c960b3c002b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,11 +20,11 @@ """ -from copy import deepcopy import numbers import os -from pathlib import Path import re +from copy import deepcopy +from pathlib import Path from typing 
import Any, Dict, Optional, Union import numpy as np diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e228749abef36..f107a8f40774b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,7 +18,6 @@ import warnings from pathlib import Path from typing import Dict, Iterable, List, Optional, Union -import warnings import torch from torch.utils.data import DataLoader From c51f9468672dfc93ee7da2f5dc1aecf115bc7a70 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 17 Jan 2021 16:41:57 +0530 Subject: [PATCH 09/24] correct rebase --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 38ec8c4251c30..6d85525f6d10e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -616,7 +616,7 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: - self.trainer.run_evaluation(test_mode=False, on_epoch=True) + self.trainer.run_evaluation(on_epoch=True) # reset stage to train self.trainer.logger_connector.set_stage("train") From 2c8ed939530a2d833ab21da351888be6df7b3d3c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 17 Jan 2021 23:36:37 +0530 Subject: [PATCH 10/24] gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d6ae2ef48ed01..0e51e29f9ae4e 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,6 @@ pytorch\ lightning test-reports/ wandb .forked/ + +# ctags +tags From 5879528998299e6dbbe029cc22b074a8895d542b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 24 Jan 2021 00:31:32 +0530 Subject: [PATCH 11/24] ref --- pytorch_lightning/trainer/evaluation_loop.py | 3 ++ pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 34 ++++++++------------ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c72b188ac4a31..ab3965db18ea4 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -71,6 +71,9 @@ def get_evaluation_dataloaders(self, max_batches): return dataloaders, max_batches + def should_skip_evaluation(self, max_batches): + return len(max_batches) == 0 + def on_evaluation_start(self, *args, **kwargs): if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f107a8f40774b..46944537358f9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -599,7 +599,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False): dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) # check if we want to skip this evaluation - if sum(max_batches) == 0: + if self.evaluation_loop.should_skip_evaluation(max_batches): return [], [] # ref model diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 6d85525f6d10e..a7dbb30a605b7 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -91,13 +91,10 @@ def num_optimizers(self): return num_optimizers def should_skip_training(self): - if self.trainer.current_epoch 
>= self.trainer.max_epochs: - return True - - if self.trainer.limit_train_batches == 0: - return True - - return False + return ( + self.trainer.current_epoch >= self.trainer.max_epochs + or self.trainer.limit_train_batches == 0 + ) def on_train_start(self): # clear cache before training @@ -620,15 +617,14 @@ def run_training_epoch(self): # reset stage to train self.trainer.logger_connector.set_stage("train") - should_skip_eval = sum(self.trainer.num_val_batches) == 0 - should_train_only_check = not self.trainer.enable_validation and should_skip_eval + should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) + should_train_only_check = self.trainer.disable_validation or should_skip_eval - if should_skip_eval or should_train_only_check: + if should_train_only_check: # update epoch level lr_schedulers self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - - self.check_checkpoint_callback(should_train_only_check) - self.check_early_stopping_callback(should_train_only_check) + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) # increment the global step once # progress global step according to grads progress @@ -883,19 +879,17 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): if on_epoch: should_check_val = ( - can_check_val - and ((is_val_check_batch and epoch_end_val_check) - or self.trainer.should_stop - or is_last_batch_for_infinite_dataset) + (is_val_check_batch and epoch_end_val_check) + or self.trainer.should_stop + or is_last_batch_for_infinite_dataset ) else: should_check_val = ( - can_check_val - and is_val_check_batch + is_val_check_batch and not epoch_end_val_check ) - return should_check_val + return should_check_val and can_check_val def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step From 2e6c601c85b57c2f5b9f1a003046486e089d38ae Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 24 Jan 2021 02:53:29 +0530 Subject: [PATCH 12/24] add tests --- tests/checkpointing/test_model_checkpoint.py | 91 +++++++++++++++++--- 1 file changed, 78 insertions(+), 13 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3de26ef1a6fb6..0ea39df3b6b76 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os import pickle import platform @@ -49,26 +50,90 @@ def validation_epoch_end(self, outputs): @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.parametrize('save_top_k', [-1]) -def test_model_checkpoint_correct_score(tmpdir, save_top_k): - """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path""" - tutils.reset_seed() +@pytest.mark.parametrize( + "validation_step,val_dataloaders,monitor", + [ + ('base', "base", 'val_log'), + ('base', "base", 'train_log_epoch'), + (None, "base", 'train_log_epoch'), + ("base", None, 'train_log_epoch') + ], +) +def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor): + """ + Test that when a model checkpoint is saved, it saves with + the correct score appended to ckpt_path and checkpoint data + """ + max_epochs = 3 + limit_train_batches = 5 + limit_val_batches = 7 - model = LogInTwoMethods() + class CustomBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) + self.val_logs = torch.randn(max_epochs, limit_val_batches) + + def training_step(self, batch, batch_idx): + out = super().training_step(batch, batch_idx) + log_value = self.train_log_epochs[self.current_epoch, batch_idx] + self.log('train_log', log_value, on_epoch=True) + return out - filename = "{val_acc:.4f}-{epoch}" + def validation_step(self, batch, batch_idx): + out = super().validation_step(batch, batch_idx) + log_value = self.val_logs[self.current_epoch, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return out - checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor='val_acc', save_top_k=save_top_k) + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] - trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2) + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' + checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) + + model = CustomBoringModel() + + if validation_step is None: + model.validation_step = None + if val_dataloaders is None: + model.val_dataloaders = None + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint], + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + max_epochs=max_epochs, + progress_bar_refresh_rate=0, + ) trainer.fit(model) ckpt_files = list(Path(tmpdir).glob('*.ckpt')) - - metrics = trainer.dev_debugger.logged_metrics - expected_filenames = {f'val_acc={metric["val_acc"]:.4f}-epoch={metric["epoch"]}.ckpt' for metric in metrics} - for ckpt_file in ckpt_files: - assert os.path.basename(ckpt_file) in expected_filenames + scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + assert len(ckpt_files) == len(scores) == max_epochs + + for epoch in range(max_epochs): + score = scores[epoch] + expected_score = getattr(model, f'{monitor}s')[epoch].mean().item() + expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt' + assert math.isclose(score, expected_score, rel_tol=1e-4) + + chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) + assert chk['epoch'] == epoch + 1 + assert chk['global_step'] == limit_train_batches * (epoch + 1) + + 
mc_specific_data = chk['callbacks'][type(checkpoint)] + assert mc_specific_data['dirpath'] == checkpoint.dirpath + assert mc_specific_data['monitor'] == monitor + assert mc_specific_data['current_score'] == score + + lr_scheduler_specific_data = chk['lr_schedulers'][0] + assert lr_scheduler_specific_data['_step_count'] == epoch + 2 + assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1)) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) From b3d601f65cbd09cdba0e0afd677f48dc65cb6578 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Jan 2021 01:04:52 +0530 Subject: [PATCH 13/24] fix --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index ab3965db18ea4..ea759a1c29750 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -72,7 +72,7 @@ def get_evaluation_dataloaders(self, max_batches): return dataloaders, max_batches def should_skip_evaluation(self, max_batches): - return len(max_batches) == 0 + return sum(max_batches) == 0 def on_evaluation_start(self, *args, **kwargs): if self.trainer.testing: From d84996ab48cddfc71ac85b3caaac306668cababf Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Jan 2021 01:44:57 +0530 Subject: [PATCH 14/24] add early stopping test --- .../callbacks/model_checkpoint.py | 6 +-- .../connectors/checkpoint_connector.py | 6 +-- tests/callbacks/test_early_stopping.py | 39 +++++++++++++++++-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6f10bc8bb63b8..e8481d671c8e7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -166,7 +166,7 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period - self.last_global_step_saved = -1 + self._last_global_step_saved = -1 self.prefix = prefix self.current_score = None self.best_k_models = {} @@ -231,7 +231,7 @@ def save_checkpoint(self, trainer, pl_module): or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # don't save anything during sanity check - or self.last_global_step_saved == global_step # already saved at the last step + or self._last_global_step_saved == global_step # already saved at the last step ): return @@ -239,7 +239,7 @@ def save_checkpoint(self, trainer, pl_module): self._validate_monitor_key(trainer) # track epoch when ckpt was last checked - self.last_global_step_saved = global_step + self._last_global_step_saved = global_step # what can be monitored monitor_candidates = self._monitor_candidates(trainer) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 34d572de84c51..79e8239ecff5d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -387,10 +387,10 @@ def save_checkpoint(self, filepath, weights_only: bool = False): filepath: write-target file's path weights_only: saving model weights only """ - # dump states as a checkpoint dictionary object - checkpoint = self.dump_checkpoint(weights_only) - if self.trainer.is_global_zero: + # dump states as a checkpoint dictionary object + checkpoint = 
self.dump_checkpoint(weights_only) + # write the checkpoint dictionary on the file if self.trainer.accelerator_backend: checkpoint = self.trainer.accelerator_backend.on_save(checkpoint) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 6fa7e5f3567b6..c8099d11f6707 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -113,11 +113,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_ep class ModelOverrideValidationReturn(EvalModelTemplate): validation_return_values = torch.Tensor(loss_values) - count = 0 def validation_epoch_end(self, outputs): - loss = self.validation_return_values[self.count] - self.count += 1 + loss = self.validation_return_values[self.current_epoch] return {"test_val_loss": loss} model = ModelOverrideValidationReturn() @@ -133,6 +131,41 @@ def validation_epoch_end(self, outputs): assert trainer.current_epoch == expected_stop_epoch +@pytest.mark.parametrize('validation_step', ['base', None]) +@pytest.mark.parametrize( + "loss_values, patience, expected_stop_epoch", + [ + ([6, 5, 5, 5, 5, 5], 3, 4), + ([6, 5, 4, 4, 3, 3], 1, 3), + ([6, 5, 6, 5, 5, 5], 3, 4), + ], +) +def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch): + """Test to ensure that early stopping is not triggered before patience is exhausted.""" + + class ModelOverrideTrainReturn(EvalModelTemplate): + train_return_values = torch.Tensor(loss_values) + + def training_epoch_end(self, outputs): + loss = self.train_return_values[self.current_epoch] + self.log('train_loss', loss) + + model = ModelOverrideTrainReturn() + + if validation_step is None: + model.validation_step = None + + early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + num_sanity_val_steps=0, + max_epochs=10, + ) + trainer.fit(model) + assert trainer.current_epoch == expected_stop_epoch + + def test_pickling(tmpdir): early_stopping = EarlyStopping() From 549eb898fc3e477db2bc91d19860a2072b03b64c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Jan 2021 01:46:17 +0530 Subject: [PATCH 15/24] trigger From 260d1f50100bdd374c879b26befda232e9ce08b3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Jan 2021 01:59:30 +0530 Subject: [PATCH 16/24] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87943bbbb4ab0..d21b5ac23d74c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620)) +- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208)) + + ## [1.1.5] - 2021-01-19 From 85af96896b9c3b9ca36b91a05bbd220ebf76a93b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Jan 2021 02:22:14 +0530 Subject: [PATCH 17/24] rev --- .../trainer/connectors/checkpoint_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 79e8239ecff5d..34d572de84c51 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -387,10 +387,10 @@ def save_checkpoint(self, filepath, weights_only: bool = False): filepath: write-target file's path weights_only: saving model weights only """ - if self.trainer.is_global_zero: - # dump states as a checkpoint dictionary object - checkpoint = self.dump_checkpoint(weights_only) + # dump states as a checkpoint dictionary object + checkpoint = self.dump_checkpoint(weights_only) + if self.trainer.is_global_zero: # write the checkpoint dictionary on the file if self.trainer.accelerator_backend: checkpoint = self.trainer.accelerator_backend.on_save(checkpoint) From 740a07ec778daf2fd2f09f817648768acc8846b3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 27 Jan 2021 03:39:52 +0530 Subject: [PATCH 18/24] 1.3 --- tests/checkpointing/test_model_checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 35d73a618fb5b..723e67075690b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math import logging +import math import os import pickle import platform import re from argparse import Namespace +from distutils.version import LooseVersion from pathlib import Path from unittest import mock from unittest.mock import Mock @@ -134,7 +135,8 @@ def configure_optimizers(self): lr_scheduler_specific_data = chk['lr_schedulers'][0] assert lr_scheduler_specific_data['_step_count'] == epoch + 2 - assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1)) + if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1)) From 465579b4a6e0e230d4fdbbb04e1cb2f12e1fdae8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Jan 2021 00:24:38 +0530 Subject: [PATCH 19/24] log --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 34d572de84c51..6953ff29330d6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -400,6 +400,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False): if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] rank_zero_warn( - 'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}' + 'Warning, `hyper_parameters` dropped from checkpoint.' + f' An attribute is not picklable {err}' ) atomic_save(checkpoint, filepath) From 40bf21be6ef5d87285bdc248b313ee17a5d679a6 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 1 Feb 2021 11:47:25 +0530 Subject: [PATCH 20/24] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/training_loop.py | 23 ++++++++++------------ 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ef700bb88966f..7c70ee817449a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -630,9 +630,9 @@ def run_training_epoch(self): self.trainer.logger_connector.set_stage("train") should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) - should_train_only_check = self.trainer.disable_validation or should_skip_eval + should_train_only = self.trainer.disable_validation or should_skip_eval - if should_train_only_check: + if should_train_only: # update epoch level lr_schedulers self.trainer.optimizer_connector.update_learning_rates(interval='epoch') self.check_checkpoint_callback(True) self.check_early_stopping_callback(True) @@ -893,17 +893,14 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches - if on_epoch: - should_check_val = ( - (is_val_check_batch and epoch_end_val_check) - or self.trainer.should_stop - or is_last_batch_for_infinite_dataset - ) - else: - should_check_val = ( - is_val_check_batch - and not epoch_end_val_check - ) + should_check_val = ( + (is_val_check_batch and 
epoch_end_val_check) + or self.trainer.should_stop + or is_last_batch_for_infinite_dataset + ) if on_epoch else ( + is_val_check_batch + and not epoch_end_val_check + ) return should_check_val and can_check_val From 6feabacd622667dcdb3e4ec1e5ac0688b20f8820 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 5 Feb 2021 15:50:45 +0530 Subject: [PATCH 21/24] Update pytorch_lightning/trainer/training_loop.py --- pytorch_lightning/trainer/training_loop.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e9c45028237b6..0b030b1e39a85 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -91,10 +91,7 @@ def num_optimizers(self): return num_optimizers def should_skip_training(self): - return ( - self.trainer.current_epoch >= self.trainer.max_epochs - or self.trainer.limit_train_batches == 0 - ) + return self.trainer.current_epoch >= self.trainer.max_epochs or self.trainer.num_training_batches == 0 def on_train_start(self): # clear cache before training From cffe27fcf212a27604365bce14c3ffb8342c5183 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 5 Feb 2021 15:58:32 +0530 Subject: [PATCH 22/24] Update CHANGELOG.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index daaaa58eb0dd2..eeec734fdf817 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208)) -- Fixed `TensorBoardLogger` not closing `SummaryWriter` on `finalize` ([#5696](https://github.com/PyTorchLightning/pytorch-lightning/pull/5696)) - Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775)) From 42276adc81c4e320713d96a4e45b344d4057a6e2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 8 Feb 2021 08:51:02 +0100 Subject: [PATCH 23/24] Apply suggestions from code review --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5885dbee7db8..dcef8bf61c78d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - - Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208)) - - - Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775)) From bcdd9ed1ce4faca4c9e5b9fd4c3820558c2103c9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 8 Feb 2021 09:00:05 +0100 Subject: [PATCH 24/24] date --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dcef8bf61c78d..7934003915850 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [1.1.8] - 2021-02-06 +## [1.1.8] - 2021-02-08 ### Fixed