From 89f284d6fbafcf6aadac7abf1af08ee3fba39865 Mon Sep 17 00:00:00 2001
From: Shuying Sun
Date: Tue, 23 Mar 2021 12:06:24 -0700
Subject: [PATCH 01/39] Fix some test errors

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++
 tests/core/test_metric_result_integration.py | 3 +++
 tests/core/test_results.py | 4 +++-
 tests/metrics/utils.py | 2 +-
 tests/utilities/test_all_gather_grad.py | 3 ++-
 5 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index ea1efd6e15873..1383964cbc789 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,6 +21,7 @@
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
+import numpy

 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self):

     def setup(self, model):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

         # pass in a state q
         smp = mp.get_context("spawn")

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 0b797dff0e42f..ffbe508816403 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -16,6 +16,8 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torchmetrics import Metric
+import numpy
+import os

 import tests.helpers.utils as tutils
 from pytorch_lightning.core.step_result import Result
@@ -96,6 +98,7 @@ def test_result_reduce_ddp():
     tutils.set_random_master_port()

     worldsize = 2
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize)

diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 9586344d8c0d9..74c4a0c212564 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -26,11 +26,12 @@
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
+import os
+import numpy


 def _setup_ddp(rank, worldsize):
     import os
-
     os.environ["MASTER_ADDR"] = "localhost"
     # initialize the process group
@@ -51,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result):
 def test_result_reduce_ddp():
     """Make sure result logging works with DDP"""
     tutils.reset_seed()
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     tutils.set_random_master_port()

     worldsize = 2

diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py
index f1f17d0624936..4aac65257a504 100644
--- a/tests/metrics/utils.py
+++ b/tests/metrics/utils.py
@@ -26,7 +26,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "8088"
+    os.environ["MASTER_PORT"] = "29501"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 259f9f4c09871..a9f38a9e1d88c 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -13,7 +13,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "8088"
+    os.environ["MASTER_PORT"] = "29501"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
@@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size):

 @RunIf(skip_windows=True)
 def test_all_gather_ddp():
     world_size = 3
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size)

From 536c1323b0e6715fb5919196ea48b0fcddddcd66 Mon Sep 17 00:00:00 2001
From: Shuying Sun
Date: Wed, 24 Mar 2021 01:17:20 -0700
Subject: [PATCH 02/39] checkpoint consolidation

---
 pytorch_lightning/callbacks/base.py | 4 +++
 pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++
 .../callbacks/lambda_function.py | 3 ++
 .../callbacks/model_checkpoint.py | 31 ++++++++++++++++
 pytorch_lightning/trainer/callback_hook.py | 7 ++++
 .../callback_hook_validator.py | 5 +++
 pytorch_lightning/trainer/training_loop.py | 35 ++-----------------
 tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++----
 tests/helpers/utils.py | 2 +-
 .../trainer/logging_/test_logger_connector.py | 1 +
 10 files changed, 99 insertions(+), 39 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index db507fa991446..ffb26f38ca821 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

+    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
+        """Called at the very end of the train epoch."""
+        pass
+
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 4448de8e4834b..0de8ff6f0b505 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

+    def on_train_epoch_final_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+        if (
+            trainer.state != TrainerState.FITTING or trainer.sanity_checking
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we run early stopping
+        # at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self._run_early_stopping_check(trainer)
+
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met

diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 58324e363cd37..2a56e1c8ac6e0 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,6 +53,7 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
+        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -155,3 +156,5 @@ def __init__(
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
+        if on_train_epoch_final_end is not None:
+            self.on_train_epoch_final_end = on_train_epoch_final_end

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2a0c108ba7603..9436720e3819b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

+    def on_train_epoch_final_end(self, trainer, pl_module):
+        """
+        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
+        """
+        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
+        if (
+            self._should_skip_saving_checkpoint(trainer)
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we checkpoint at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self.save_checkpoint(trainer)
+
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Checkpoints can be saved at the end of the training.
+        """
+        trainer.global_step -= 1
+        if (
+            not self._should_skip_saving_checkpoint(trainer)
+            and trainer.checkpoint_connector.has_trained
+        ):
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving latest checkpoint...")
+            self.save_checkpoint(trainer)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 8823d48a7817e..c53c21ad04bc3 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

+    def on_train_epoch_final_end(self) -> None:
+        """
+        Called at the very end of the train epoch.
+        """
+        for callback in self.callbacks:
+            callback.on_train_epoch_final_end(self, self.lightning_module)
+
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index 534dad5199e9b..e7884124df314 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,6 +100,11 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

+    @staticmethod
+    def _on_train_epoch_final_end_log():
+        """Called at the very end of the train epoch."""
+        return {"on_step": [False], "on_epoch": [False, True]}
+
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c3ba34ca66d2d..1d498a0a9ff6c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,12 +121,6 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")

@@ -145,28 +139,6 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def check_early_stopping_callback(self, should_update):
-        # TODO bake this logic into the EarlyStopping callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -562,15 +534,14 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

-        if should_train_only:
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

+        if should_train_only:
+            self.trainer.call_hook('on_train_epoch_final_end')
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 75f25b90fa45f..e0c295a843a21 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
+        if period > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
+        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    if verbose and save_last and not should_validate:
+        # no validation, hence checkpoint triggered at the end of each training epoch
+        assert caplog.messages.count('Saving latest checkpoint...') == 0
+    else:
+        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index f5c1726a423bb..493d32d3fe454 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"


 def init_checkpoint_callback(logger):

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 3db0a8eaa065b..b2727177bcacd 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
     'on_train_batch_start',
     'on_train_end',
     'on_train_epoch_end',
+    'on_train_epoch_final_end',
     'on_train_epoch_start',
     'on_train_start',
     'on_validation_batch_end',

From f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:37:52 -0700
Subject: [PATCH 03/39] Update ddp_spawn.py

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 941025b36c0ac..87d7fa5faecac 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,7 +21,6 @@
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
-import numpy

 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self):

     def setup(self, model):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

         # pass in a state q
         smp = mp.get_context("spawn")

From bf70e431b3ce4893de804e0f3b5d59e79346d6d7 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:41:33 -0700
Subject: [PATCH 04/39] Update test_metric_result_integration.py

---
 tests/core/test_metric_result_integration.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index ffbe508816403..0b797dff0e42f 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -16,8 +16,6 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torchmetrics import Metric
-import numpy
-import os

 import tests.helpers.utils as tutils
 from pytorch_lightning.core.step_result import Result
@@ -98,7 +96,6 @@ def test_result_reduce_ddp():
     tutils.set_random_master_port()

     worldsize = 2
-    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize)

From ea749068785bbad689a12066544893b1605f20c5 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:42:16 -0700
Subject: [PATCH 05/39] Update test_results.py

---
 tests/core/test_results.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 74c4a0c212564..f25ab0c40a6ea 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -26,8 +26,6 @@
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
-import os
-import numpy


 def _setup_ddp(rank, worldsize):
@@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result):
 def test_result_reduce_ddp():
     """Make sure result logging works with DDP"""
     tutils.reset_seed()
-    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     tutils.set_random_master_port()

     worldsize = 2

From a9aae99f6ed6e9388ecf1d8a7bd79966176a65af Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:43:04 -0700
Subject: [PATCH 06/39] Update utils.py

---
 tests/helpers/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 493d32d3fe454..f5c1726a423bb 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = "29501"
+    os.environ['MASTER_PORT'] = str(port)


 def init_checkpoint_callback(logger):

From 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:43:43 -0700
Subject: [PATCH 07/39] Update utils.py

---
 tests/metrics/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py
index 4aac65257a504..f1f17d0624936 100644
--- a/tests/metrics/utils.py
+++ b/tests/metrics/utils.py
@@ -26,7 +26,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
+    os.environ["MASTER_PORT"] = "8088"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

From 0d23d75bc91e4e0b7805712e394cb093fac22841 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:44:18 -0700
Subject: [PATCH 08/39] Update test_all_gather_grad.py

---
 tests/utilities/test_all_gather_grad.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index a9f38a9e1d88c..f1860b10326e9 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -13,7 +13,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
+    os.environ["MASTER_PORT"] = "8088"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

From ca6f98ba8ff835368ae3ef91e435e4d4f458c45b Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:51:55 -0700
Subject: [PATCH 09/39] Update test_all_gather_grad.py

---
 tests/utilities/test_all_gather_grad.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index f1860b10326e9..259f9f4c09871 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size):

 @RunIf(skip_windows=True)
 def test_all_gather_ddp():
     world_size = 3
-    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size)

From 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 01:57:31 -0700
Subject: [PATCH 10/39] Update test_results.py

---
 tests/core/test_results.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index f25ab0c40a6ea..334062ae994a2 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -30,6 +30,7 @@

 def _setup_ddp(rank, worldsize):
     import os
+
     os.environ["MASTER_ADDR"] = "localhost"

     # initialize the process group

From 7635b4f47bcef43b9bbe677ad96a3bad135246a5 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:10 -0700
Subject: [PATCH 11/39] Revert "Update test_results.py"

This reverts commit 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1.

---
 tests/core/test_results.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 334062ae994a2..f25ab0c40a6ea 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -30,7 +30,6 @@

 def _setup_ddp(rank, worldsize):
     import os
-
     os.environ["MASTER_ADDR"] = "localhost"

     # initialize the process group

From d64f90cbc748de193a02237acd6ac686750b82d1 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:20 -0700
Subject: [PATCH 12/39] Revert "Merge pull request #1 from
 shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate"

This reverts commit c5053da789f9d04d2c967a65adf4fb026dc134b8, reversing
changes made to 0d23d75bc91e4e0b7805712e394cb093fac22841.

---
 tests/utilities/test_all_gather_grad.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 259f9f4c09871..f1860b10326e9 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size):

 @RunIf(skip_windows=True)
 def test_all_gather_ddp():
     world_size = 3
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size)

From dcdcd29731061c919b15ab0b56669259817a81c4 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:36 -0700
Subject: [PATCH 13/39] Revert "Update test_all_gather_grad.py"

This reverts commit 0d23d75bc91e4e0b7805712e394cb093fac22841.

---
 tests/utilities/test_all_gather_grad.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index f1860b10326e9..a9f38a9e1d88c 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -13,7 +13,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "8088"
+    os.environ["MASTER_PORT"] = "29501"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

From 8651d54d79396eaaba16d7eb1e769a1e91d5702e Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:40 -0700
Subject: [PATCH 14/39] Revert "Update utils.py"

This reverts commit 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c.
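Note: patches 06-15 keep toggling MASTER_PORT between the hard-coded values "8088"/"29501" and the randomly drawn str(port). A fixed port avoids the flakiness of a bad random draw but collides when two DDP test runs share a machine. A third option, sketched below assuming only the Python standard library (find_free_port is a hypothetical helper, not a PyTorch Lightning API), is to let the OS hand out a free port:

    import os
    import socket

    def find_free_port() -> int:
        # Binding to port 0 makes the kernel choose an unused port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("localhost", 0))
            return s.getsockname()[1]

    os.environ["MASTER_PORT"] = str(find_free_port())

There is still a small window in which another process can grab the port between this probe and init_process_group, so tests using this pattern may need to retry on failure.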
---
 tests/metrics/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py
index f1f17d0624936..4aac65257a504 100644
--- a/tests/metrics/utils.py
+++ b/tests/metrics/utils.py
@@ -26,7 +26,7 @@ def setup_ddp(rank, world_size):
     """ Setup ddp enviroment """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "8088"
+    os.environ["MASTER_PORT"] = "29501"

     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

From 15f4b9e59bec52b07dddb55eeda4d9a68b8bd6d2 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:45 -0700
Subject: [PATCH 15/39] Revert "Update utils.py"

This reverts commit a9aae99f6ed6e9388ecf1d8a7bd79966176a65af.

---
 tests/helpers/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index f5c1726a423bb..493d32d3fe454 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"


 def init_checkpoint_callback(logger):

From 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:48 -0700
Subject: [PATCH 16/39] Revert "Update test_results.py"

This reverts commit ea749068785bbad689a12066544893b1605f20c5.

---
 tests/core/test_results.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index f25ab0c40a6ea..74c4a0c212564 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -26,6 +26,8 @@
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
+import os
+import numpy


 def _setup_ddp(rank, worldsize):
@@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result):
 def test_result_reduce_ddp():
     """Make sure result logging works with DDP"""
     tutils.reset_seed()
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     tutils.set_random_master_port()

     worldsize = 2

From 6c095b2370a2afe9d24918a5798ce1ebffed7e0d Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:52 -0700
Subject: [PATCH 17/39] Revert "Update test_metric_result_integration.py"

This reverts commit bf70e431b3ce4893de804e0f3b5d59e79346d6d7.
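Note: the MKL_SERVICE_FORCE_INTEL=1 settings that this group of patches adds and removes work around a known mkl-service abort ("MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1"), which can kill spawned workers when numpy is imported after torch. Assuming an MKL-linked numpy build, the essential part of the workaround is only that the variable is exported before the workers start, e.g.:

    import os

    # Must be set before mp.spawn() starts workers that import numpy/torch.
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

    import torch.multiprocessing as mp

Exporting MKL_THREADING_LAYER=GNU instead is a commonly suggested alternative.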
---
 tests/core/test_metric_result_integration.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 0b797dff0e42f..ffbe508816403 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -16,6 +16,8 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torchmetrics import Metric
+import numpy
+import os

 import tests.helpers.utils as tutils
 from pytorch_lightning.core.step_result import Result
@@ -96,6 +98,7 @@ def test_result_reduce_ddp():
     tutils.set_random_master_port()

     worldsize = 2
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
     mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize)

From 8222dc98ead37d961a52b7366070aa10f66d92d1 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:06:55 -0700
Subject: [PATCH 18/39] Revert "Update ddp_spawn.py"

This reverts commit f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b.

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 87d7fa5faecac..941025b36c0ac 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,6 +21,7 @@
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
+import numpy

 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self):

     def setup(self, model):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

         # pass in a state q
         smp = mp.get_context("spawn")

From 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:17:01 -0700
Subject: [PATCH 19/39] Revert "checkpoint consolidation"

This reverts commit 536c1323b0e6715fb5919196ea48b0fcddddcd66.
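Note: the "checkpoint consolidation" change reverted here (and re-applied below) moves the end-of-epoch checkpoint and early-stopping triggers out of the training loop's check_checkpoint_callback/check_early_stopping_callback helpers and into a single on_train_epoch_final_end hook, fired once per training epoch after any validation run. While the patch is applied, a user callback can observe the new hook; a minimal sketch, keeping in mind that the hook exists only inside this patch series and not in released PyTorch Lightning:

    from pytorch_lightning.callbacks import Callback

    class EpochEndAudit(Callback):
        def on_train_epoch_final_end(self, trainer, pl_module):
            # Fires at the very end of each training epoch.
            print(f"epoch {trainer.current_epoch} finished at step {trainer.global_step}")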
---
 pytorch_lightning/callbacks/base.py | 4 ---
 pytorch_lightning/callbacks/early_stopping.py | 15 --------
 .../callbacks/lambda_function.py | 3 --
 .../callbacks/model_checkpoint.py | 31 ----------------
 pytorch_lightning/trainer/callback_hook.py | 7 ----
 .../callback_hook_validator.py | 5 ---
 pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++--
 tests/checkpointing/test_model_checkpoint.py | 35 ++++---------------
 tests/helpers/utils.py | 2 +-
 .../trainer/logging_/test_logger_connector.py | 1 -
 10 files changed, 39 insertions(+), 99 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index ffb26f38ca821..db507fa991446 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

-    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called at the very end of the train epoch."""
-        pass
-
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 0de8ff6f0b505..4448de8e4834b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.state != TrainerState.FITTING or trainer.sanity_checking
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we run early stopping
-        # at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self._run_early_stopping_check(trainer)
-
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met

diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 2a56e1c8ac6e0..58324e363cd37 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,7 +53,6 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
-        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -156,5 +155,3 @@ def __init__(
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
-        if on_train_epoch_final_end is not None:
-            self.on_train_epoch_final_end = on_train_epoch_final_end

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9436720e3819b..2a0c108ba7603 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        """
-        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
-        """
-        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we checkpoint at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self.save_checkpoint(trainer)
-
-    def on_train_end(self, trainer, *args, **kwargs) -> None:
-        """
-        Checkpoints can be saved at the end of the training.
-        """
-        trainer.global_step -= 1
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and trainer.checkpoint_connector.has_trained
-        ):
-            if self.save_last and self.verbose:
-                rank_zero_info("Saving latest checkpoint...")
-            self.save_checkpoint(trainer)
-        trainer.global_step += 1
-
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index c53c21ad04bc3..8823d48a7817e 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

-    def on_train_epoch_final_end(self) -> None:
-        """
-        Called at the very end of the train epoch.
-        """
-        for callback in self.callbacks:
-            callback.on_train_epoch_final_end(self, self.lightning_module)
-
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index e7884124df314..534dad5199e9b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,11 +100,6 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

-    @staticmethod
-    def _on_train_epoch_final_end_log():
-        """Called at the very end of the train epoch."""
-        return {"on_step": [False], "on_epoch": [False, True]}
-
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 1d498a0a9ff6c..c3ba34ca66d2d 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,6 +121,12 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.trainer.global_step -= 1
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.trainer.global_step += 1
+
         # hook
         self.trainer.call_hook("on_train_end")

@@ -139,6 +145,28 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
+
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
+                rank_zero_info("Saving latest checkpoint...")
+
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -534,14 +562,15 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

+        if should_train_only:
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

-        if should_train_only:
-            self.trainer.call_hook('on_train_epoch_final_end')
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e0c295a843a21..75f25b90fa45f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
-        if period > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
-        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    if verbose and save_last and not should_validate:
-        # no validation, hence checkpoint triggered at the end of each training epoch
-        assert caplog.messages.count('Saving latest checkpoint...') == 0
-    else:
-        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 493d32d3fe454..f5c1726a423bb 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = "29501"
+    os.environ['MASTER_PORT'] = str(port)


 def init_checkpoint_callback(logger):

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index b2727177bcacd..3db0a8eaa065b 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir):
     'on_train_batch_start',
     'on_train_end',
     'on_train_epoch_end',
-    'on_train_epoch_final_end',
     'on_train_epoch_start',
     'on_train_start',
     'on_validation_batch_end',

From 7a369f47e1a94d701fce48c994cc3f2da266dad0 Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:19:37 -0700
Subject: [PATCH 20/39] Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac.
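Note: re-applying the consolidation also restores the on_train_epoch_final_end argument on LambdaCallback, so the hook can be used without subclassing. Illustrative usage, again only meaningful while this series is applied:

    from pytorch_lightning.callbacks import LambdaCallback

    log_epoch_end = LambdaCallback(
        on_train_epoch_final_end=lambda trainer, pl_module: print(
            f"epoch {trainer.current_epoch} done"
        )
    )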
---
 pytorch_lightning/callbacks/base.py | 4 +++
 pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++
 .../callbacks/lambda_function.py | 3 ++
 .../callbacks/model_checkpoint.py | 31 ++++++++++++++++
 pytorch_lightning/trainer/callback_hook.py | 7 ++++
 .../callback_hook_validator.py | 5 +++
 pytorch_lightning/trainer/training_loop.py | 35 ++-----------------
 tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++----
 tests/helpers/utils.py | 2 +-
 .../trainer/logging_/test_logger_connector.py | 1 +
 10 files changed, 99 insertions(+), 39 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index db507fa991446..ffb26f38ca821 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

+    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
+        """Called at the very end of the train epoch."""
+        pass
+
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 4448de8e4834b..0de8ff6f0b505 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

+    def on_train_epoch_final_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+        if (
+            trainer.state != TrainerState.FITTING or trainer.sanity_checking
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we run early stopping
+        # at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self._run_early_stopping_check(trainer)
+
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met

diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 58324e363cd37..2a56e1c8ac6e0 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,6 +53,7 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
+        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -155,3 +156,5 @@ def __init__(
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
+        if on_train_epoch_final_end is not None:
+            self.on_train_epoch_final_end = on_train_epoch_final_end

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2a0c108ba7603..9436720e3819b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

+    def on_train_epoch_final_end(self, trainer, pl_module):
+        """
+        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
+        """
+        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
+        if (
+            self._should_skip_saving_checkpoint(trainer)
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we checkpoint at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self.save_checkpoint(trainer)
+
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Checkpoints can be saved at the end of the training.
+        """
+        trainer.global_step -= 1
+        if (
+            not self._should_skip_saving_checkpoint(trainer)
+            and trainer.checkpoint_connector.has_trained
+        ):
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving latest checkpoint...")
+            self.save_checkpoint(trainer)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 8823d48a7817e..c53c21ad04bc3 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

+    def on_train_epoch_final_end(self) -> None:
+        """
+        Called at the very end of the train epoch.
+        """
+        for callback in self.callbacks:
+            callback.on_train_epoch_final_end(self, self.lightning_module)
+
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index 534dad5199e9b..e7884124df314 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,6 +100,11 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

+    @staticmethod
+    def _on_train_epoch_final_end_log():
+        """Called at the very end of the train epoch."""
+        return {"on_step": [False], "on_epoch": [False, True]}
+
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c3ba34ca66d2d..1d498a0a9ff6c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,12 +121,6 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")

@@ -145,28 +139,6 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def check_early_stopping_callback(self, should_update):
-        # TODO bake this logic into the EarlyStopping callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -562,15 +534,14 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

-        if should_train_only:
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

+        if should_train_only:
+            self.trainer.call_hook('on_train_epoch_final_end')
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 75f25b90fa45f..e0c295a843a21 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
+        if period > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
+        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    if verbose and save_last and not should_validate:
+        # no validation, hence checkpoint triggered at the end of each training epoch
+        assert caplog.messages.count('Saving latest checkpoint...') == 0
+    else:
+        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index f5c1726a423bb..493d32d3fe454 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"


 def init_checkpoint_callback(logger):

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 3db0a8eaa065b..b2727177bcacd 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
     'on_train_batch_start',
     'on_train_end',
     'on_train_epoch_end',
+    'on_train_epoch_final_end',
     'on_train_epoch_start',
     'on_train_start',
     'on_validation_batch_end',

From b4a0b9e9e1e0a08e50979facc2f0fc74187de2ee Mon Sep 17 00:00:00 2001
From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com>
Date: Wed, 24 Mar 2021 02:32:43 -0700
Subject: [PATCH 21/39] Revert "Revert "Revert "checkpoint consolidation"""

This reverts commit 7a369f47e1a94d701fce48c994cc3f2da266dad0.

---
 pytorch_lightning/callbacks/base.py | 4 ---
 pytorch_lightning/callbacks/early_stopping.py | 15 --------
 .../callbacks/lambda_function.py | 3 --
 .../callbacks/model_checkpoint.py | 31 ----------------
 pytorch_lightning/trainer/callback_hook.py | 7 ----
 .../callback_hook_validator.py | 5 ---
 pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++--
 tests/checkpointing/test_model_checkpoint.py | 35 ++++---------------
 tests/helpers/utils.py | 2 +-
 .../trainer/logging_/test_logger_connector.py | 1 -
 10 files changed, 39 insertions(+), 99 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index ffb26f38ca821..db507fa991446 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

-    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called at the very end of the train epoch."""
-        pass
-
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 0de8ff6f0b505..4448de8e4834b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.state != TrainerState.FITTING or trainer.sanity_checking
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we run early stopping
-        # at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self._run_early_stopping_check(trainer)
-
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met

diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 2a56e1c8ac6e0..58324e363cd37 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,7 +53,6 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
-        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -156,5 +155,3 @@ def __init__(
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
-        if on_train_epoch_final_end is not None:
-            self.on_train_epoch_final_end = on_train_epoch_final_end

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9436720e3819b..2a0c108ba7603 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        """
-        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
-        """
-        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we checkpoint at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self.save_checkpoint(trainer)
-
-    def on_train_end(self, trainer, *args, **kwargs) -> None:
-        """
-        Checkpoints can be saved at the end of the training.
-        """
-        trainer.global_step -= 1
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and trainer.checkpoint_connector.has_trained
-        ):
-            if self.save_last and self.verbose:
-                rank_zero_info("Saving latest checkpoint...")
-            self.save_checkpoint(trainer)
-        trainer.global_step += 1
-
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index c53c21ad04bc3..8823d48a7817e 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

-    def on_train_epoch_final_end(self) -> None:
-        """
-        Called at the very end of the train epoch.
-        """
-        for callback in self.callbacks:
-            callback.on_train_epoch_final_end(self, self.lightning_module)
-
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index e7884124df314..534dad5199e9b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,11 +100,6 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

-    @staticmethod
-    def _on_train_epoch_final_end_log():
-        """Called at the very end of the train epoch."""
-        return {"on_step": [False], "on_epoch": [False, True]}
-
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 1d498a0a9ff6c..c3ba34ca66d2d 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,6 +121,12 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.trainer.global_step -= 1
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.trainer.global_step += 1
+
         # hook
         self.trainer.call_hook("on_train_end")

@@ -139,6 +145,28 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
+
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
+                rank_zero_info("Saving latest checkpoint...")
+
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -534,14 +562,15 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

+        if should_train_only:
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

-        if should_train_only:
-            self.trainer.call_hook('on_train_epoch_final_end')
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e0c295a843a21..75f25b90fa45f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
-        if period > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the
correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 0ce7e056ac47436bc727f91f8eed335fc736696c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:44 -0700 Subject: [PATCH 22/39] Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc98ead37d961a52b7366070aa10f66d92d1. 
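For context on the line this revert restores: spawned DDP workers that import an MKL-linked numpy build can abort with an mkl-service threading-layer incompatibility error unless the Intel layer is forced before the workers start. A minimal sketch of the pattern outside Lightning (the worker body and the two-process world size are illustrative assumptions, not code from this repository):

    import os

    import torch.multiprocessing as mp


    def _worker(rank: int, world_size: int) -> None:
        # Importing an MKL-linked library inside the spawned interpreter is
        # what can trip the threading-layer clash without the override below.
        import numpy  # noqa: F401


    if __name__ == "__main__":
        # Set before mp.spawn so the child processes inherit it from the
        # parent environment.
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
        mp.spawn(_worker, args=(2, ), nprocs=2)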
--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From fe9736d94bfcac3e084eec2d63e62351d3618175 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:49 -0700 Subject: [PATCH 23/39] Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2370a2afe9d24918a5798ce1ebffed7e0d. --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From c314ef6d30373c2c94fdeceef6ee7b9d961a48c9 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:56 -0700 Subject: [PATCH 24/39] Revert "Revert "Update test_results.py"" This reverts commit 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b. --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From c3feda03d7fbb25dcf1917e718209f98d0503327 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:05 -0700 Subject: [PATCH 25/39] Revert "Revert "Update utils.py"" This reverts commit 8651d54d79396eaaba16d7eb1e769a1e91d5702e. 
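The back-and-forth on MASTER_PORT in these test helpers (a fixed 8088, a fixed 29501, or a randomly drawn port) is ultimately about avoiding address-in-use failures when several gloo process groups initialize on the same host. A hedged alternative to pinning any number is to let the OS hand out a free port; a sketch (the helper name is an illustrative assumption, not part of the codebase):

    import os
    import socket


    def _find_free_port() -> int:
        # Binding to port 0 asks the OS for an unused TCP port. The socket is
        # closed before the port is used, so a small race with other
        # processes on the same host remains possible.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("localhost", 0))
            return s.getsockname()[1]


    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(_find_free_port())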
--- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From c759477a0a9462f812a880a8cee7c09b3f432520 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:13 -0700 Subject: [PATCH 26/39] Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29731061c919b15ab0b56669259817a81c4. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 4e67db2a1fed55e6d6e3fa09766aaa55c7995ca1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 10:57:58 -0700 Subject: [PATCH 27/39] modify distributed environment to make test pass --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 ++- tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 3 +++ tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 2 +- 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..0b4b7680076a3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything +import numpy log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 
2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 4211f0c1bb0440823866395b862bd50cf0518b29 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 11:56:23 -0700 Subject: [PATCH 28/39] consolidate training loop checkpoints v1 --- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 33 +++++++++++++---- .../trainer/logging_/test_logger_connector.py | 1 + 9 files changed, 96 insertions(+), 38 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..c3048f1801a59 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -85,6 +85,10 @@ def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[ """Called when the train epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_validation_epoch_start(self, trainer, pl_module: 
LightningModule) -> None: """Called when the val epoch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the training + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py
b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. + """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2291016cc40ce..cb51e23311854 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -114,12 +114,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -138,28 +132,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -555,15 +527,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..3d9e0845023d2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,13 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +670,13 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 
1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +832,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 0bf539830dfcb012dd39de8521a618175ad980f3 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 00:16:19 -0700 Subject: [PATCH 29/39] consolidate training loop checkpoints v2 --- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/callbacks/early_stopping.py | 10 +-- .../callbacks/lambda_function.py | 6 +- .../callbacks/model_checkpoint.py | 38 +++++----- pytorch_lightning/trainer/callback_hook.py | 6 +- .../callback_hook_validator.py | 4 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/callbacks/test_lambda_function.py | 12 ++- .../test_checkpoint_callback_frequency.py | 4 +- tests/checkpointing/test_model_checkpoint.py | 76 +++++++++++++------ tests/loggers/test_tensorboard.py | 2 +- .../trainer/logging_/test_logger_connector.py | 2 +- 12 files changed, 97 insertions(+), 69 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index c3048f1801a59..9f446d2aca327 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -85,8 +85,8 @@ def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[ """Called when the train epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" + def on_train_epoch_without_validation_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch where validation is not enabled.""" pass def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..81483e6cec3b0 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,20 +143,14 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): + def 
on_train_epoch_without_validation_end(self, trainer, pl_module): from pytorch_lightning.trainer.states import TrainerState if ( trainer.state != TrainerState.FITTING or trainer.sanity_checking or not trainer.checkpoint_connector.has_trained ): return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) + self._run_early_stopping_check(trainer) def _run_early_stopping_check(self, trainer): """ diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..d802044185ca9 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, + on_train_epoch_without_validation_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end + if on_train_epoch_without_validation_end is not None: + self.on_train_epoch_without_validation_end = on_train_epoch_without_validation_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..cea6eafc444da 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -111,6 +111,7 @@ class ModelCheckpoint(Callback): This argument has been deprecated in v1.3 and will be removed in v1.5. Use ``every_n_val_epochs`` instead. + trigger_on_train_end: Whether to trigger the save_checkpoint at the end of training. 
Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -186,6 +187,7 @@ def __init__( every_n_train_steps: Optional[int] = None, every_n_val_epochs: Optional[int] = None, period: Optional[int] = None, + trigger_on_train_end: bool = False, ): super().__init__() self.monitor = monitor @@ -206,7 +208,7 @@ def __init__( self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) - self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period, trigger_on_train_end) self.__validate_init_configuration() def on_pretrain_routine_start(self, trainer, pl_module): @@ -238,35 +240,33 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): + def on_train_epoch_without_validation_end(self, trainer, pl_module): """ - at the end of each training epoch, checkpoint only when validation is skipped or disabled + at the end of each training epoch where validation is disabled """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) if ( self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained ): return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) + self.save_checkpoint(trainer) def on_train_end(self, trainer, *args, **kwargs) -> None: """ checkpoints can be saved at the end of the trianing """ + if not self._trigger_on_train_end: + return + # need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step trainer.global_step -= 1 if ( not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained ): if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) + rank_zero_info("Saving last checkpoint...") + self.save_checkpoint(trainer, is_on_train_end=True) trainer.global_step += 1 def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -282,7 +282,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None): + def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: bool = False): """ Performs the main logic around saving a checkpoint. 
This method runs on all ranks, it is the responsibility of `self.save_function` @@ -297,7 +297,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): global_step = trainer.global_step self._add_backward_monitor_support(trainer) - self._validate_monitor_key(trainer) + self._validate_monitor_key(trainer, is_on_train_end) # track epoch when ckpt was last checked self._last_global_step_saved = global_step @@ -387,7 +387,7 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] def __init_triggers( - self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int], trigger_on_train_end: bool, ) -> None: # Default to running once after each validation epoch if neither @@ -409,6 +409,7 @@ def __init_triggers( self._every_n_val_epochs = period self._period = self._every_n_val_epochs + self._trigger_on_train_end = trigger_on_train_end @property def period(self) -> Optional[int]: @@ -613,13 +614,14 @@ def _add_backward_monitor_support(self, trainer): " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning ) - def _validate_monitor_key(self, trainer): + def _validate_monitor_key(self, trainer, is_on_train_end: bool): metrics = trainer.logger_connector.callback_metrics # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics): + if self.monitor is not None and not self._is_valid_monitor_key(metrics) and not is_on_train_end: m = ( - f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" + f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics " + "and it is not triggered on train end:" f" {list(metrics.keys())}. " f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?" ) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..3e74592b9baa2 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,12 +92,12 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: + def on_train_epoch_without_validation_end(self) -> None: """ - Called when at the very end of train epoch. + Called when at the very end of train epoch where validation is not enabled. 
""" for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) + callback.on_train_epoch_without_validation_end(self, self.lightning_module) def on_validation_epoch_start(self): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..1fbedc71c4253 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -101,8 +101,8 @@ def _on_train_epoch_end_log(): return {"on_step": [False], "on_epoch": [False, True]} @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" + def _on_train_epoch_without_validation_end_log(): + """Called when at the very end of train epoch where validation is not enabled.""" return {"on_step": [False], "on_epoch": [False, True]} @staticmethod diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cb51e23311854..1ed5230264049 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -533,7 +533,7 @@ def run_training_epoch(self): self.trainer.training = True if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') + self.trainer.call_hook('on_train_epoch_without_validation_end') # increment the global step once # progress global step according to grads progress diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index c2edfb176f164..cb140c49c4f3c 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect +import pytest from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.helpers.boring_model import BoringModel - -def test_lambda_call(tmpdir): +@pytest.mark.parametrize('should_validate', [True, False]) +def test_lambda_call(tmpdir, should_validate: bool): seed_everything(42) class CustomModel(BoringModel): @@ -27,12 +28,15 @@ def on_train_epoch_start(self): if self.current_epoch > 1: raise KeyboardInterrupt + model = CustomModel() checker = set() hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint") - - model = CustomModel() + if not should_validate: + model.validation_step = None + else: + hooks.remove("on_train_epoch_without_validation_end") trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 7926bc46dd290..7f2c8d19984f0 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -50,7 +50,7 @@ def test_mc_called(tmpdir): @mock.patch('torch.save') @pytest.mark.parametrize( ['epochs', 'val_check_interval', 'expected'], - [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)], + [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)], ) def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int): @@ -73,7 +73,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter (1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), - (2, 2, 0.3, 7), + (2, 2, 0.3, 6), ]) def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3d9e0845023d2..adad9833232e0 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -453,6 +453,7 @@ def test_model_checkpoint_file_extension(tmpdir): dirpath=tmpdir, save_top_k=1, save_last=True, + trigger_on_train_end=True, ) trainer = Trainer( default_root_dir=tmpdir, @@ -594,10 +595,17 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): @pytest.mark.parametrize("period", list(range(4))) -def test_model_checkpoint_period(tmpdir, period: int): +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + period=period, + trigger_on_train_end=trigger_on_train_end, + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -609,22 +617,28 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] if period > 0 else [] ) - 
expected.append(final_epoch_ckpt) + if trigger_on_train_end and (period == 0 or epochs % period != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=every_n_val_epochs, + trigger_on_train_end=trigger_on_train_end, ) trainer = Trainer( default_root_dir=tmpdir, @@ -637,18 +651,21 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else [] ) - expected.append(final_epoch_ckpt) + + if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. 
""" model = LogInTwoMethods() epochs = 5 @@ -657,7 +674,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc filename='{epoch}', save_top_k=-1, every_n_val_epochs=(2 * every_n_val_epochs), - period=every_n_val_epochs + period=every_n_val_epochs, + trigger_on_train_end=trigger_on_train_end, ) trainer = Trainer( default_root_dir=tmpdir, @@ -670,13 +688,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else [] ) - expected.append(final_epoch_ckpt) + if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -817,30 +836,39 @@ def test_default_checkpoint_behavior(tmpdir): @pytest.mark.parametrize('max_epochs', [1, 2]) +@pytest.mark.parametrize('every_n_val_epochs', [2, 3]) @pytest.mark.parametrize('should_validate', [True, False]) @pytest.mark.parametrize('save_last', [True, False]) @pytest.mark.parametrize('verbose', [True, False]) +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) + def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool + tmpdir, caplog, max_epochs: int, every_n_val_epochs: int, should_validate: bool, save_last: bool, verbose: bool, trigger_on_train_end: bool, ): - """Tests 'Saving latest checkpoint...' log""" + """Tests 'Saving last checkpoint...' 
log""" model = LogInTwoMethods() if not should_validate: model.validation_step = None - ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) + ckpt = ModelCheckpoint( + monitor='early_stop_on', + dirpath=tmpdir, + every_n_val_epochs=every_n_val_epochs, + save_top_k=0, + save_last=save_last, + verbose=verbose, + trigger_on_train_end=trigger_on_train_end, + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + expected = False + if should_validate and save_last and verbose and trigger_on_train_end: + expected = (max_epochs % every_n_val_epochs != 0) + assert caplog.messages.count('Saving last checkpoint...') == expected def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 1a85270c6dcbb..d5e490f360eb4 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -55,7 +55,7 @@ def __init__(self, b1=0.5, b2=0.999): assert len(yaml_params.keys()) == 2 # verify artifacts - assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1 + assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 0 # verify tb logs event_acc = EventAccumulator(folder_path) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..d3fb38e118fb6 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', + 'on_train_epoch_without_validation_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From bbb5f8372302b257b023ed580e5c29a91e934eb5 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 01:03:44 -0700 Subject: [PATCH 30/39] consolidate training loop checkpoints v3 --- .../callbacks/model_checkpoint.py | 6 +++++- .../plugins/training_type/ddp_spawn.py | 2 -- pytorch_lightning/trainer/training_loop.py | 2 -- tests/checkpointing/test_model_checkpoint.py | 18 ++++++++++++------ tests/core/test_metric_result_integration.py | 3 --- tests/core/test_results.py | 4 +--- tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 +-- 9 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cea6eafc444da..5d6469fd3595b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -387,7 +387,11 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] def __init_triggers( - self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int], trigger_on_train_end: bool, + self, + every_n_train_steps: Optional[int], + every_n_val_epochs: Optional[int], + period: Optional[int], + trigger_on_train_end: bool, ) -> None: # Default to 
running once after each validation epoch if neither diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0b4b7680076a3..15f936f882e8d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,6 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -import numpy log = logging.getLogger(__name__) @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1ed5230264049..f369d46c3cee9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -19,14 +19,12 @@ import numpy as np import torch -from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing -from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index adad9833232e0..59e06cccc102f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -623,7 +623,7 @@ def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool else [] ) if trigger_on_train_end and (period == 0 or epochs % period != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -658,7 +658,7 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger ) if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -693,8 +693,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc if every_n_val_epochs > 0 else [] ) - if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -841,9 +841,15 @@ def test_default_checkpoint_behavior(tmpdir): 
@pytest.mark.parametrize('save_last', [True, False]) @pytest.mark.parametrize('verbose', [True, False]) @pytest.mark.parametrize('trigger_on_train_end', [False, True]) - def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, every_n_val_epochs: int, should_validate: bool, save_last: bool, verbose: bool, trigger_on_train_end: bool, + tmpdir, + caplog, + max_epochs: int, + every_n_val_epochs: int, + should_validate: bool, + save_last: bool, + verbose: bool, + trigger_on_train_end: bool, ): """Tests 'Saving last checkpoint...' log""" model = LogInTwoMethods() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..9586344d8c0d9 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,12 +26,11 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group @@ -52,7 +51,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,7 +44,6 @@ def 
_test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From db37add19d7381f6130c98d0e9651da6d94606f4 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 02:51:00 -0700 Subject: [PATCH 31/39] consolidate training loop checkpoints v4 --- tests/callbacks/test_lambda_function.py | 1 + tests/trainer/connectors/test_callback_connector.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index cb140c49c4f3c..b6013e9bd67a4 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -18,6 +18,7 @@ from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.helpers.boring_model import BoringModel + @pytest.mark.parametrize('should_validate', [True, False]) def test_lambda_call(tmpdir, should_validate: bool): seed_everything(42) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..aba0e43e7b51d 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -57,7 +57,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): callback0 = StatefulCallback0() callback1 = StatefulCallback1() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states") + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states", trigger_on_train_end=True) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, From 51aefb8f147544ee5238d1b2fb8a5063019529bb Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 17:04:54 -0700 Subject: [PATCH 32/39] consolidate training end model checkpoint --- .../callbacks/model_checkpoint.py | 43 ++++++++-- pytorch_lightning/trainer/training_loop.py | 9 -- .../test_checkpoint_callback_frequency.py | 4 +- tests/checkpointing/test_model_checkpoint.py | 85 +++++++++++++++---- tests/loggers/test_tensorboard.py | 2 +- .../connectors/test_callback_connector.py | 2 +- 6 files changed, 109 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..5b0fe5cddf33d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -111,6 +111,9 @@ class ModelCheckpoint(Callback): This argument has been deprecated in v1.3 and will be removed in v1.5. Use ``every_n_val_epochs`` instead. + trigger_on_train_end: Whether to trigger the save_checkpoint at the end of training. + By default, it is turned off. 
+ Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -186,6 +189,7 @@ def __init__( every_n_train_steps: Optional[int] = None, every_n_val_epochs: Optional[int] = None, period: Optional[int] = None, + trigger_on_train_end: bool = False, ): super().__init__() self.monitor = monitor @@ -206,7 +210,7 @@ def __init__( self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) - self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period, trigger_on_train_end) self.__validate_init_configuration() def on_pretrain_routine_start(self, trainer, pl_module): @@ -238,6 +242,24 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of training + """ + if not self._trigger_on_train_end: + return + # need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving last checkpoint...") + self.save_checkpoint(trainer, is_on_train_end=True) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, @@ -251,7 +273,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None): + def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: bool = False): """ Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function` @@ -266,7 +288,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): global_step = trainer.global_step self._add_backward_monitor_support(trainer) - self._validate_monitor_key(trainer) + self._validate_monitor_key(trainer, is_on_train_end) # track epoch when ckpt was last checked self._last_global_step_saved = global_step @@ -356,7 +378,11 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] def __init_triggers( - self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + self, + every_n_train_steps: Optional[int], + every_n_val_epochs: Optional[int], + period: Optional[int], + trigger_on_train_end: bool = False, ) -> None: # Default to running once after each validation epoch if neither @@ -378,6 +404,7 @@ def __init_triggers( self._every_n_val_epochs = period self._period = self._every_n_val_epochs + self._trigger_on_train_end = trigger_on_train_end @property def period(self) -> Optional[int]: @@ -582,13 +609,13 @@ def _add_backward_monitor_support(self, trainer): " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning ) - def _validate_monitor_key(self, trainer): + def _validate_monitor_key(self, trainer, is_on_train_end: bool): metrics = trainer.logger_connector.callback_metrics - # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics): + if self.monitor is not None and not self._is_valid_monitor_key(metrics) and not is_on_train_end: m = ( - f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" + f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics " + "and it is not triggered on train end:" f" {list(metrics.keys())}. " f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?" ) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2291016cc40ce..4bb09d7c12d4d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -114,12 +114,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -143,9 +137,6 @@ def check_checkpoint_callback(self, should_update, is_last=False): if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = self.trainer.checkpoint_callbacks - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - model = self.trainer.lightning_module for cb in callbacks: diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 7926bc46dd290..7f2c8d19984f0 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -50,7 +50,7 @@ def test_mc_called(tmpdir): @mock.patch('torch.save') @pytest.mark.parametrize( ['epochs', 'val_check_interval', 'expected'], - [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)], + [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)], ) def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int): @@ -73,7 +73,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter (1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), - (2, 2, 0.3, 7), + (2, 2, 0.3, 6), ]) def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..6d4d44916dbc2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -453,6 +453,7 @@ def test_model_checkpoint_file_extension(tmpdir): dirpath=tmpdir, save_top_k=1, save_last=True, + trigger_on_train_end=True, ) trainer = Trainer( default_root_dir=tmpdir, @@ -594,10 +595,17 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): @pytest.mark.parametrize("period", list(range(4))) -def test_model_checkpoint_period(tmpdir, period: int): +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + period=period, + trigger_on_train_end=trigger_on_train_end, + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -609,16 +617,28 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] + if period > 0 + else [] + ) + if trigger_on_train_end and (period == 0 or epochs % period != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): 
+@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=every_n_val_epochs, + trigger_on_train_end=trigger_on_train_end, ) trainer = Trainer( default_root_dir=tmpdir, @@ -631,13 +651,21 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] + if every_n_val_epochs > 0 + else [] + ) + + if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ model = LogInTwoMethods() epochs = 5 @@ -646,7 +674,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc filename='{epoch}', save_top_k=-1, every_n_val_epochs=(2 * every_n_val_epochs), - period=every_n_val_epochs + period=every_n_val_epochs, + trigger_on_train_end=trigger_on_train_end, ) trainer = Trainer( default_root_dir=tmpdir, @@ -659,8 +688,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] + if every_n_val_epochs > 0 + else [] + ) + if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -801,17 +836,34 @@ def test_default_checkpoint_behavior(tmpdir): @pytest.mark.parametrize('max_epochs', [1, 2]) +@pytest.mark.parametrize('every_n_val_epochs', [2, 3]) @pytest.mark.parametrize('should_validate', [True, False]) @pytest.mark.parametrize('save_last', [True, False]) @pytest.mark.parametrize('verbose', [True, False]) +@pytest.mark.parametrize('trigger_on_train_end', [False, True]) def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool + tmpdir, + caplog, + max_epochs: int, + every_n_val_epochs: int, + should_validate: bool, + save_last: bool, + verbose: bool, + trigger_on_train_end: bool, ): - """Tests 'Saving latest checkpoint...' log""" + """Tests 'Saving last checkpoint...' 
log""" model = LogInTwoMethods() if not should_validate: model.validation_step = None - ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) + ckpt = ModelCheckpoint( + monitor='early_stop_on', + dirpath=tmpdir, + every_n_val_epochs=every_n_val_epochs, + save_top_k=0, + save_last=save_last, + verbose=verbose, + trigger_on_train_end=trigger_on_train_end, + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[ckpt], @@ -819,7 +871,10 @@ def test_model_checkpoint_save_last_warning( ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + expected = False + if save_last and verbose and trigger_on_train_end: + expected = (max_epochs % every_n_val_epochs != 0) + assert caplog.messages.count('Saving last checkpoint...') == expected def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 1a85270c6dcbb..d5e490f360eb4 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -55,7 +55,7 @@ def __init__(self, b1=0.5, b2=0.999): assert len(yaml_params.keys()) == 2 # verify artifacts - assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1 + assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 0 # verify tb logs event_acc = EventAccumulator(folder_path) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..aba0e43e7b51d 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -57,7 +57,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): callback0 = StatefulCallback0() callback1 = StatefulCallback1() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states") + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states", trigger_on_train_end=True) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, From e013b19d8d9fa0740ebf926c794a0c81d879a10f Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 18:38:49 -0700 Subject: [PATCH 33/39] remove distributed environment hack --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- tests/core/test_metric_result_integration.py | 3 --- tests/core/test_results.py | 3 --- tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 +-- 6 files changed, 3 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0b4b7680076a3..15f936f882e8d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,6 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -import numpy log = logging.getLogger(__name__) @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git 
a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 4e5ad9453ef7f..d67c9473bbb2e 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From d90cd64429985b9cd21da8a82723f78e11f37072 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 25 Mar 2021 19:14:00 -0700 Subject: [PATCH 34/39] consolidate on_train_end only --- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 9 ------ .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 11 ------- .../plugins/training_type/ddp_spawn.py | 1 + pytorch_lightning/trainer/callback_hook.py | 
7 ----- .../callback_hook_validator.py | 5 ---- pytorch_lightning/trainer/training_loop.py | 30 +++++++++++++++++-- tests/callbacks/test_lambda_function.py | 11 ++----- tests/checkpointing/test_model_checkpoint.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 11 files changed, 32 insertions(+), 52 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 6316e74f9c4fd..7757902bd3baf 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -85,10 +85,6 @@ def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[ """Called when the train epoch ends.""" pass - def on_train_epoch_without_validation_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch where validation is not enabled.""" - pass - def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the val epoch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 81483e6cec3b0..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,15 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_without_validation_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index d802044185ca9..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_without_validation_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_without_validation_end is not None: - self.on_train_epoch_without_validation_end = on_train_epoch_without_validation_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4dd7a149a381f..27975b142fc39 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -240,17 +240,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_without_validation_end(self, trainer, pl_module): - """ - at the end of each training epoch where validation is disabled - """ - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - self.save_checkpoint(trainer) - def on_train_end(self, trainer, *args, **kwargs) -> None: """ checkpoints can be saved at the end of training diff --git
a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 15f936f882e8d..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -78,6 +78,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index ca11d9a824971..6d434e12a2e78 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_without_validation_end(self) -> None: - """ - Called when at the very end of train epoch where validation is not enabled. - """ - for callback in self.callbacks: - callback.on_train_epoch_without_validation_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 1fbedc71c4253..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_without_validation_end_log(): - """Called when at the very end of train epoch where validation is not enabled.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8b4b37ae6198d..61455b9078714 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -19,6 +19,7 @@ import numpy as np import torch +from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin @@ -130,6 +131,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + 
cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -525,14 +548,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_without_validation_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index b6013e9bd67a4..c2edfb176f164 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import pytest from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.helpers.boring_model import BoringModel -@pytest.mark.parametrize('should_validate', [True, False]) -def test_lambda_call(tmpdir, should_validate: bool): +def test_lambda_call(tmpdir): seed_everything(42) class CustomModel(BoringModel): @@ -29,15 +27,12 @@ def on_train_epoch_start(self): if self.current_epoch > 1: raise KeyboardInterrupt - model = CustomModel() checker = set() hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint") - if not should_validate: - model.validation_step = None - else: - hooks.remove("on_train_epoch_without_validation_end") + + model = CustomModel() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 59e06cccc102f..6d4d44916dbc2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -872,7 +872,7 @@ def test_model_checkpoint_save_last_warning( with caplog.at_level(logging.INFO): trainer.fit(model) expected = False - if should_validate and save_last and verbose and trigger_on_train_end: + if save_last and verbose and trigger_on_train_end: expected = (max_epochs % every_n_val_epochs != 0) assert caplog.messages.count('Saving last checkpoint...') == expected diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 7b0f6518fd4d2..d14ed71940328 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_without_validation_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From ccd771fa9a2f87b27a10ffd345375d5a17b14413 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Fri, 9 Apr 2021 22:09:08 -0700 Subject: [PATCH 35/39] rebase --- .../callbacks/model_checkpoint.py | 38 ++++++++++++------- tests/checkpointing/test_model_checkpoint.py | 2 +- 2 files changed, 26 
insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4675c9445daf0..c38163a914735 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -28,9 +28,13 @@ import numpy as np import torch import yaml - from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import ( + rank_zero_deprecation, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -223,7 +227,7 @@ def on_train_batch_end( self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ - if self._should_skip_saving_checkpoint(trainer): + if self._should_skip_saving_checkpoint(trainer, is_on_train_end=False): return step = trainer.global_step skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) @@ -236,7 +240,8 @@ def on_validation_end(self, trainer, pl_module) -> None: checkpoints can be saved at the end of the val loop """ skip = ( - self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 + self._should_skip_saving_checkpoint(trainer, is_on_train_end=False) + or self._every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 ) if skip: @@ -249,17 +254,13 @@ def on_train_end(self, trainer, *args, **kwargs) -> None: """ if not self._trigger_on_train_end: return - # need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - trainer.global_step -= 1 if ( - not self._should_skip_saving_checkpoint(trainer) + not self._should_skip_saving_checkpoint(trainer, is_on_train_end=True) and trainer.checkpoint_connector.has_trained ): if self.save_last and self.verbose: rank_zero_info("Saving last checkpoint...") self.save_checkpoint(trainer, is_on_train_end=True) - trainer.global_step += 1 def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { @@ -274,7 +275,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: bool = False): + def save_checkpoint( + self, trainer, unused: Optional = None, is_on_train_end: bool = False + ): """ Performs the main logic around saving a checkpoint. 
This method runs on all ranks, it is the responsibility of `self.save_function` @@ -306,13 +309,22 @@ def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: boo # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) - def _should_skip_saving_checkpoint(self, trainer) -> bool: + def _should_skip_saving_checkpoint(self, trainer, is_on_train_end: bool) -> bool: from pytorch_lightning.trainer.states import TrainerState + + if is_on_train_end: + # as we advance one step at end of training, we use global_step - 1 + # to avoid saving duplicates + is_last_saved = self._last_global_step_saved == trainer.global_step - 1 + else: + is_last_saved = self._last_global_step_saved == trainer.global_step + return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.state + != TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check - or self._last_global_step_saved == trainer.global_step # already saved at the last step + or is_last_saved # already saved at the last step ) def __validate_init_configuration(self): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7f72613555cea..8f4f959a9a6c3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -462,7 +462,7 @@ def test_model_checkpoint_file_extension(tmpdir): ) trainer.fit(model) - expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + expected = ['epoch=0-step=1.tpkc', 'last.tpkc'] assert set(expected) == set(os.listdir(tmpdir)) From 70ebc9ff68b0fab343373e056e8b4e6d1e3ff4a4 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Fri, 9 Apr 2021 22:58:54 -0700 Subject: [PATCH 36/39] modify --- .../callbacks/model_checkpoint.py | 40 ++++++------------- tests/checkpointing/test_model_checkpoint.py | 22 +++------- 2 files changed, 19 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index c38163a914735..86056ac051706 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -28,13 +28,9 @@ import numpy as np import torch import yaml + from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import ( - rank_zero_deprecation, - rank_zero_info, - rank_zero_only, - rank_zero_warn, -) +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -227,7 +223,7 @@ def on_train_batch_end( self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ - if self._should_skip_saving_checkpoint(trainer, is_on_train_end=False): + if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) @@ -240,8 +236,7 @@ def on_validation_end(self, trainer, pl_module) -> None: checkpoints can be saved at the end of the val loop """ skip = ( - self._should_skip_saving_checkpoint(trainer, 
is_on_train_end=False) - or self._every_n_val_epochs < 1 + self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 ) if skip: @@ -254,13 +249,14 @@ def on_train_end(self, trainer, *args, **kwargs) -> None: """ if not self._trigger_on_train_end: return - if ( - not self._should_skip_saving_checkpoint(trainer, is_on_train_end=True) - and trainer.checkpoint_connector.has_trained - ): + # as we advance one step at end of training, we use global_step - 1 + # to avoid saving duplicates + trainer.global_step -= 1 + if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained): if self.save_last and self.verbose: rank_zero_info("Saving last checkpoint...") self.save_checkpoint(trainer, is_on_train_end=True) + trainer.global_step += 1 def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { @@ -275,9 +271,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint( - self, trainer, unused: Optional = None, is_on_train_end: bool = False - ): + def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: bool = False): """ Performs the main logic around saving a checkpoint. This method runs on all ranks, it is the responsibility of `self.save_function` @@ -309,22 +303,14 @@ def save_checkpoint( # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) - def _should_skip_saving_checkpoint(self, trainer, is_on_train_end: bool) -> bool: + def _should_skip_saving_checkpoint(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerState - if is_on_train_end: - # as we advance one step at end of training, we use global_step - 1 - # to avoid saving duplicates - is_last_saved = self._last_global_step_saved == trainer.global_step - 1 - else: - is_last_saved = self._last_global_step_saved == trainer.global_step - return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state - != TrainerState.FITTING # don't save anything during non-fit + or trainer.state != TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check - or is_last_saved # already saved at the last step + or self._last_global_step_saved == trainer.global_step # already saved at the last step ) def __validate_init_configuration(self): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 8f4f959a9a6c3..7d3826d520b95 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -462,7 +462,7 @@ def test_model_checkpoint_file_extension(tmpdir): ) trainer.fit(model) - expected = ['epoch=0-step=1.tpkc', 'last.tpkc'] + expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] assert set(expected) == set(os.listdir(tmpdir)) @@ -616,11 +616,7 @@ def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool trainer.fit(model) # check that the correct ckpts were created - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] - if period > 0 - else [] - ) + expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] if period > 0 else []) if trigger_on_train_end and (period == 0 or epochs % period != 0): final_epoch_ckpt = 
"epoch={e}.ckpt".format(e=epochs - 1) expected.append(final_epoch_ckpt) @@ -650,11 +646,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger trainer.fit(model) # check that the correct ckpts were created - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] - if every_n_val_epochs > 0 - else [] - ) + expected = ([f"epoch={e}.ckpt" for e in range(epochs) + if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else []) if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) @@ -687,11 +680,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] - if every_n_val_epochs > 0 - else [] - ) + expected = ([f"epoch={e}.ckpt" for e in range(epochs) + if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else []) if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) expected.append(final_epoch_ckpt) From ddf76c4f3ad5c977b890e38648a352058d6d497d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Sat, 10 Apr 2021 00:23:33 -0700 Subject: [PATCH 37/39] add one more unittest for end of training with invalid monitor --- .../callbacks/model_checkpoint.py | 15 ++++++++---- tests/checkpointing/test_model_checkpoint.py | 24 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 86056ac051706..7a7f05951f35b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -301,7 +301,16 @@ def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: boo # Mode 2: save monitor=None checkpoints self._save_none_monitor_checkpoint(trainer, monitor_candidates) # Mode 3: save last checkpoints - self._save_last_checkpoint(trainer, monitor_candidates) + if self._should_save_last_checkpoint(trainer, monitor_candidates, is_on_train_end): + self._save_last_checkpoint(trainer, monitor_candidates) + + def _should_save_last_checkpoint(self, trainer, monitor_candidates, is_on_train_end) -> bool: + # we should save last checkpoint if save_last is set or + # at the end of the training, we fall back to save last checkpoint if + # we set monitor value but not existent in monitor_candidates + return self.save_last or ( + is_on_train_end and self.monitor is not None and monitor_candidates.get(self.monitor) is None + ) def _should_skip_saving_checkpoint(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerState @@ -381,7 +390,7 @@ def __init_triggers( every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int], - trigger_on_train_end: bool = False, + trigger_on_train_end: bool, ) -> None: # Default to running once after each validation epoch if neither @@ -643,8 +652,6 @@ def _monitor_candidates(self, trainer): return monitor_candidates def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): - if not self.save_last: - return filepath = self._format_checkpoint_name( self.CHECKPOINT_NAME_LAST, diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py 
index 7d3826d520b95..783180cd96fb1 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -819,6 +819,30 @@ def test_default_checkpoint_behavior(tmpdir): assert ckpts[0] == 'epoch=2-step=14.ckpt' +def test_ckpt_on_train_end_with_invalid_monitor(tmpdir): + """ Tests that the checkpoints are saved at end of training with invalid monitor.""" + + model = LogInTwoMethods() + model_cpt = ModelCheckpoint( + filename="{epoch}", + dirpath=tmpdir, + every_n_val_epochs=2, + monitor="invalid", # monitor is invalid, save_last is not set + trigger_on_train_end=True, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + progress_bar_refresh_rate=0, + callbacks=[model_cpt], + logger=False, + ) + trainer.fit(model) + # fall back to save last + expected = ['last.ckpt'] + assert set(expected) == set(os.listdir(tmpdir)) + + @pytest.mark.parametrize('max_epochs', [1, 2]) @pytest.mark.parametrize('every_n_val_epochs', [2, 3]) @pytest.mark.parametrize('should_validate', [True, False]) From 48a34e8700e4da929fe1b9779d1fca1a4aabe62d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Sat, 10 Apr 2021 00:29:04 -0700 Subject: [PATCH 38/39] add changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e4ca17b5922c..27fda3f8a8af4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898)) +- Fixed model checkpointing at end of training ([#6671](https://github.com/PyTorchLightning/pytorch-lightning/pull/6671)) + + ## [1.2.7] - 2021-04-06 ### Fixed From f9616a3838f6a8cdae6f943b70d578c464a4ff98 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 13 Apr 2021 02:01:35 -0700 Subject: [PATCH 39/39] comments, call _save_last_checkpoint directly for train end --- .../callbacks/model_checkpoint.py | 28 ++++++------- tests/checkpointing/test_model_checkpoint.py | 39 +++++++++++-------- .../connectors/test_callback_connector.py | 4 +- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7a7f05951f35b..7b0bc860192bb 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -112,7 +112,7 @@ class ModelCheckpoint(Callback): Use ``every_n_val_epochs`` instead. trigger_on_train_end: Whether to trigger the save_checkpoint at the end of training. - By default, it is turned off. + By default, it is turned off. If it is turned on, the model will be saved to file `last.ckpt`. 
Note: @@ -255,7 +255,8 @@ def on_train_end(self, trainer, *args, **kwargs) -> None: if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained): if self.save_last and self.verbose: rank_zero_info("Saving last checkpoint...") - self.save_checkpoint(trainer, is_on_train_end=True) + monitor_candidates = self._monitor_candidates(trainer) + self._save_last_checkpoint(trainer, monitor_candidates) trainer.global_step += 1 def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -271,7 +272,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: bool = False): + def save_checkpoint(self, trainer, unused: Optional = None): """ Performs the main logic around saving a checkpoint. This method runs on all ranks, it is the responsibility of `self.save_function` @@ -286,7 +287,7 @@ def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: boo global_step = trainer.global_step self._add_backward_monitor_support(trainer) - self._validate_monitor_key(trainer, is_on_train_end) + self._validate_monitor_key(trainer) # track epoch when ckpt was last checked self._last_global_step_saved = global_step @@ -301,16 +302,7 @@ def save_checkpoint(self, trainer, unused: Optional = None, is_on_train_end: boo # Mode 2: save monitor=None checkpoints self._save_none_monitor_checkpoint(trainer, monitor_candidates) # Mode 3: save last checkpoints - if self._should_save_last_checkpoint(trainer, monitor_candidates, is_on_train_end): - self._save_last_checkpoint(trainer, monitor_candidates) - - def _should_save_last_checkpoint(self, trainer, monitor_candidates, is_on_train_end) -> bool: - # we should save last checkpoint if save_last is set or - # at the end of the training, we fall back to save last checkpoint if - # we set monitor value but not existent in monitor_candidates - return self.save_last or ( - is_on_train_end and self.monitor is not None and monitor_candidates.get(self.monitor) is None - ) + self._save_last_checkpoint(trainer, monitor_candidates) def _should_skip_saving_checkpoint(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerState @@ -617,13 +609,12 @@ def _add_backward_monitor_support(self, trainer): " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning ) - def _validate_monitor_key(self, trainer, is_on_train_end: bool): + def _validate_monitor_key(self, trainer): metrics = trainer.logger_connector.callback_metrics # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics) and not is_on_train_end: + if self.monitor is not None and not self._is_valid_monitor_key(metrics): m = ( f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics " - "and it is not triggered on train end:" f" {list(metrics.keys())}. " f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?" 
) @@ -653,6 +644,9 @@ def _monitor_candidates(self, trainer): def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + if not self.save_last: + return + filepath = self._format_checkpoint_name( self.CHECKPOINT_NAME_LAST, trainer.current_epoch, diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 783180cd96fb1..ec71c57e30389 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -461,8 +461,7 @@ def test_model_checkpoint_file_extension(tmpdir): logger=False, ) trainer.fit(model) - - expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + expected = ['last.tpkc'] assert set(expected) == set(os.listdir(tmpdir)) @@ -595,13 +594,15 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): @pytest.mark.parametrize("period", list(range(4))) @pytest.mark.parametrize('trigger_on_train_end', [False, True]) -def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool): +@pytest.mark.parametrize('save_last', [False, True]) +def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool, save_last: bool): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( dirpath=tmpdir, filename='{epoch}', save_top_k=-1, + save_last=save_last, period=period, trigger_on_train_end=trigger_on_train_end, ) @@ -617,21 +618,22 @@ def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool # check that the correct ckpts were created expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] if period > 0 else []) - if trigger_on_train_end and (period == 0 or epochs % period != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) - expected.append(final_epoch_ckpt) + if save_last and (period > 0 or trigger_on_train_end): + expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) @pytest.mark.parametrize('trigger_on_train_end', [False, True]) -def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): +@pytest.mark.parametrize('save_last', [False, True]) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool, save_last: bool): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( dirpath=tmpdir, filename='{epoch}', save_top_k=-1, + save_last=save_last, every_n_val_epochs=every_n_val_epochs, trigger_on_train_end=trigger_on_train_end, ) @@ -649,15 +651,17 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else []) - if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) - expected.append(final_epoch_ckpt) + if save_last and (every_n_val_epochs > 0 or trigger_on_train_end): + expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_val_epochs", list(range(4))) @pytest.mark.parametrize('trigger_on_train_end', [False, True]) -def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs, trigger_on_train_end: bool): +@pytest.mark.parametrize('save_last', [False, True]) +def test_model_checkpoint_every_n_val_epochs_and_period( + tmpdir, every_n_val_epochs, 
trigger_on_train_end: bool, save_last: bool ): """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ model = LogInTwoMethods() epochs = 5 @@ -665,6 +669,7 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc dirpath=tmpdir, filename='{epoch}', save_top_k=-1, + save_last=save_last, every_n_val_epochs=(2 * every_n_val_epochs), period=every_n_val_epochs, trigger_on_train_end=trigger_on_train_end, @@ -682,9 +687,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc # check that the correct ckpts were created expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else []) - if trigger_on_train_end and (every_n_val_epochs == 0 or epochs % every_n_val_epochs != 0): - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1) - expected.append(final_epoch_ckpt) + if save_last and (every_n_val_epochs > 0 or trigger_on_train_end): + expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @@ -819,7 +823,8 @@ def test_default_checkpoint_behavior(tmpdir): assert ckpts[0] == 'epoch=2-step=14.ckpt' -def test_ckpt_on_train_end_with_invalid_monitor(tmpdir): +@pytest.mark.parametrize('save_last', [False, True]) +def test_ckpt_on_train_end_with_invalid_monitor(tmpdir, save_last: bool): """ Tests that the checkpoints are saved at end of training with invalid monitor.""" model = LogInTwoMethods() @@ -828,6 +833,7 @@ def test_ckpt_on_train_end_with_invalid_monitor(tmpdir, save_last: bool): dirpath=tmpdir, every_n_val_epochs=2, - monitor="invalid", # monitor is invalid, save_last is not set + monitor="invalid", # monitor is invalid + save_last=save_last, trigger_on_train_end=True, ) trainer = Trainer( @@ -838,8 +844,7 @@ def test_ckpt_on_train_end_with_invalid_monitor(tmpdir, save_last: bool): logger=False, ) trainer.fit(model) - # fall back to save last - expected = ['last.ckpt'] + expected = ['last.ckpt'] if save_last else [] assert set(expected) == set(os.listdir(tmpdir)) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index aba0e43e7b51d..1ae96f8e4bd9f 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -57,7 +57,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): callback0 = StatefulCallback0() callback1 = StatefulCallback1() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states", trigger_on_train_end=True) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True, trigger_on_train_end=True) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -67,7 +67,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): ) trainer.fit(model) - ckpt = torch.load(str(tmpdir / "all_states.ckpt")) + ckpt = torch.load(str(tmpdir / "last.ckpt")) state0 = ckpt["callbacks"][type(callback0)] state1 = ckpt["callbacks"][type(callback1)] assert "content0" in state0 and state0["content0"] == 0
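
A minimal usage sketch of the API these patches converge on (the state after PATCH 39/39), assuming the patched branch: `trigger_on_train_end` stays off by default, and because `on_train_end` now calls `_save_last_checkpoint` directly, the final save writes only `last.ckpt` and only when `save_last=True`. `BoringModel` is the stub model this repo's tests already use; the `dirpath` value below is a hypothetical placeholder, not part of the patches:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from tests.helpers.boring_model import BoringModel  # stub LightningModule from this repo's tests

    # Save '{epoch}.ckpt' after every 2nd validation epoch, then write
    # 'last.ckpt' once more when trainer.fit() finishes; the train-end save
    # is skipped when save_last=False because _save_last_checkpoint returns early.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",  # hypothetical output directory
        filename="{epoch}",
        save_top_k=-1,
        save_last=True,
        every_n_val_epochs=2,
        trigger_on_train_end=True,  # off by default; opt in to the train-end save
    )

    trainer = Trainer(max_epochs=5, callbacks=[checkpoint_callback])
    trainer.fit(BoringModel())

Since the train-end save bypasses `save_checkpoint` (and with it `_validate_monitor_key`), an unresolvable `monitor` can no longer raise at the end of training, which is the behavior `test_ckpt_on_train_end_with_invalid_monitor` exercises.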