From 89f284d6fbafcf6aadac7abf1af08ee3fba39865 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 23 Mar 2021 12:06:24 -0700 Subject: [PATCH 01/40] Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 4 +++- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 ++- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ea1efd6e15873..1383964cbc789 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9586344d8c0d9..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,11 +26,12 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group @@ -51,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def 
setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 536c1323b0e6715fb5919196ea48b0fcddddcd66 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 01:17:20 -0700 Subject: [PATCH 02/40] checkpoint consolidation --- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: 
Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. 
+ """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ 
b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:37:52 
-0700 Subject: [PATCH 03/40] Update ddp_spawn.py --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From bf70e431b3ce4893de804e0f3b5d59e79346d6d7 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:41:33 -0700 Subject: [PATCH 04/40] Update test_metric_result_integration.py --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From ea749068785bbad689a12066544893b1605f20c5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:42:16 -0700 Subject: [PATCH 05/40] Update test_results.py --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From a9aae99f6ed6e9388ecf1d8a7bd79966176a65af Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:04 -0700 Subject: [PATCH 06/40] Update utils.py --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - 
os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): From 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:43 -0700 Subject: [PATCH 07/40] Update utils.py --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 0d23d75bc91e4e0b7805712e394cb093fac22841 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:44:18 -0700 Subject: [PATCH 08/40] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From ca6f98ba8ff835368ae3ef91e435e4d4f458c45b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:51:55 -0700 Subject: [PATCH 09/40] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:57:31 -0700 Subject: [PATCH 10/40] Update test_results.py --- tests/core/test_results.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..334062ae994a2 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,6 +30,7 @@ def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From 7635b4f47bcef43b9bbe677ad96a3bad135246a5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:10 -0700 Subject: [PATCH 11/40] Revert "Update test_results.py" This reverts commit 
9d4a2b891d2a4b37e21529a444bda1883d1b5ed1. --- tests/core/test_results.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 334062ae994a2..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,7 +30,6 @@ def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From d64f90cbc748de193a02237acd6ac686750b82d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:20 -0700 Subject: [PATCH 12/40] Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" This reverts commit c5053da789f9d04d2c967a65adf4fb026dc134b8, reversing changes made to 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From dcdcd29731061c919b15ab0b56669259817a81c4 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:36 -0700 Subject: [PATCH 13/40] Revert "Update test_all_gather_grad.py" This reverts commit 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 8651d54d79396eaaba16d7eb1e769a1e91d5702e Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:40 -0700 Subject: [PATCH 14/40] Revert "Update utils.py" This reverts commit 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c. 
--- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 15f4b9e59bec52b07dddb55eeda4d9a68b8bd6d2 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:45 -0700 Subject: [PATCH 15/40] Revert "Update utils.py" This reverts commit a9aae99f6ed6e9388ecf1d8a7bd79966176a65af. --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): From 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:48 -0700 Subject: [PATCH 16/40] Revert "Update test_results.py" This reverts commit ea749068785bbad689a12066544893b1605f20c5. --- tests/core/test_results.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From 6c095b2370a2afe9d24918a5798ce1ebffed7e0d Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:52 -0700 Subject: [PATCH 17/40] Revert "Update test_metric_result_integration.py" This reverts commit bf70e431b3ce4893de804e0f3b5d59e79346d6d7. 
--- tests/core/test_metric_result_integration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From 8222dc98ead37d961a52b7366070aa10f66d92d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:55 -0700 Subject: [PATCH 18/40] Revert "Update ddp_spawn.py" This reverts commit f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b. --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..941025b36c0ac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:17:01 -0700 Subject: [PATCH 19/40] Revert "checkpoint consolidation" This reverts commit 536c1323b0e6715fb5919196ea48b0fcddddcd66. 
--- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint 
only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the 
correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 7a369f47e1a94d701fce48c994cc3f2da266dad0 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:19:37 -0700 Subject: [PATCH 20/40] Revert "Revert "checkpoint consolidation"" This reverts commit 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac. 
--- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint 
only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. + """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = 
[f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From b4a0b9e9e1e0a08e50979facc2f0fc74187de2ee Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:32:43 -0700 Subject: [PATCH 21/40] Revert "Revert "Revert "checkpoint consolidation""" This reverts commit 7a369f47e1a94d701fce48c994cc3f2da266dad0. 
--- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint 
only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the 
correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 0ce7e056ac47436bc727f91f8eed335fc736696c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:44 -0700 Subject: [PATCH 22/40] Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc98ead37d961a52b7366070aa10f66d92d1. 
--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From fe9736d94bfcac3e084eec2d63e62351d3618175 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:49 -0700 Subject: [PATCH 23/40] Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2370a2afe9d24918a5798ce1ebffed7e0d. --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From c314ef6d30373c2c94fdeceef6ee7b9d961a48c9 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:56 -0700 Subject: [PATCH 24/40] Revert "Revert "Update test_results.py"" This reverts commit 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b. --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From c3feda03d7fbb25dcf1917e718209f98d0503327 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:05 -0700 Subject: [PATCH 25/40] Revert "Revert "Update utils.py"" This reverts commit 8651d54d79396eaaba16d7eb1e769a1e91d5702e. 
--- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From c759477a0a9462f812a880a8cee7c09b3f432520 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:13 -0700 Subject: [PATCH 26/40] Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29731061c919b15ab0b56669259817a81c4. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 4e67db2a1fed55e6d6e3fa09766aaa55c7995ca1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 10:57:58 -0700 Subject: [PATCH 27/40] modify distributed environment to make test pass --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 ++- tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 3 +++ tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 2 +- 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..0b4b7680076a3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything +import numpy log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 
2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From cb39c74669de95384c61e0be23c977886961fbf2 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 5 Apr 2021 18:52:55 -0700 Subject: [PATCH 28/40] remove fragile error handling --- pytorch_lightning/trainer/trainer.py | 13 +++---- .../optimization/test_manual_optimization.py | 39 ++++++------------- .../optimization/test_multiple_optimizers.py | 5 ++- 3 files changed, 19 insertions(+), 38 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 27dcd6fe9aa0d..468f716d0455b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -599,6 +599,7 @@ def run_train(self) -> None: self.train_loop.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: + self.train_loop.on_train_end() return # early stopping @@ -607,6 +608,7 @@ def run_train(self) -> None: if self.should_stop: if met_min_epochs and met_min_steps: + self.train_loop.on_train_end() return else: log.info( @@ -614,7 +616,6 @@ def run_train(self) -> None: f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...' 
) - # hook self.train_loop.on_train_end() @@ -624,14 +625,10 @@ def run_train(self) -> None: if not self.interrupted: self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() - except (RuntimeError, AssertionError): - # if an exception is raised, the finally block is executed and can hide the actual exception - # that was initially raised if `on_train_end` also raises an exception. we want to avoid that - # for assertions and other runtime errors so we aren't misled while debugging + self.train_loop.on_train_end() + except: print_exc() - finally: - # hook - self.train_loop.on_train_end() + raise def run_evaluation(self, on_epoch=False): if not (self.evaluating or self.sanity_checking): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 8ad603a7677ea..07fce6030f5d6 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -29,7 +29,7 @@ @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_multiple_optimizers_manual(tmpdir): +def test_multiple_optimizers_manual_no_return(tmpdir): """ Tests that only training_step can be used """ @@ -68,8 +68,9 @@ def training_step(self, batch, batch_idx): assert torch.all(self.layer.weight.grad == 0) def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 + # outputs is empty as training_step does not return + # and it is not automatic optimization + assert len(outputs) == 0 def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -279,8 +280,9 @@ def training_step(self, batch, batch_idx): assert torch.all(self.layer.weight.grad == 0) def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 + # outputs is empty as training_step does not return + # and it is not automatic optimization + assert len(outputs) == 0 def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -310,7 +312,7 @@ def configure_optimizers(self): @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @RunIf(min_gpus=1, amp_apex=True) -def test_multiple_optimizers_manual_apex(tmpdir): +def test_multiple_optimizers_manual_apex_no_return(tmpdir): """ Tests that only training_step can be used """ @@ -353,8 +355,9 @@ def training_step(self, batch, batch_idx): assert torch.all(self.layer.weight.grad == 0) def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 + # outputs is empty as training_step does not return + # and it is not automatic optimization + assert len(outputs) == 0 def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -724,10 +727,6 @@ def optimizer_closure(): weight_after = self.layer.weight.clone() assert not torch.equal(weight_before, weight_after) - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer @@ -788,10 +787,6 @@ def optimizer_closure(): else: assert self.layer.weight.grad is not None - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - def configure_optimizers(self): optimizer = 
torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer @@ -845,10 +840,6 @@ def optimizer_closure(): opt.step(closure=optimizer_closure) opt.zero_grad() - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer @@ -923,10 +914,6 @@ def dis_closure(): opt_dis.step(closure=dis_closure) opt_dis.zero_grad() - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - def configure_optimizers(self): optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) @@ -1031,10 +1018,6 @@ def dis_closure(): if make_dis_optimizer_step: opt_dis.step(closure=dis_closure) - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - def configure_optimizers(self): optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index 5f0ca34015df0..24b32c8725963 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -134,8 +134,9 @@ def training_step(self, batch, batch_idx): opt_b.zero_grad() def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 + # outputs is empty as training_step does not return + # and it is not automatic optimization + assert len(outputs) == 0 model = TestModel() model.val_dataloader = None From 1f1201858e736fe3eadd4a0a5c3c00835307778d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 6 Apr 2021 23:45:19 -0700 Subject: [PATCH 29/40] draft fix v1 --- pytorch_lightning/trainer/trainer.py | 11 ++++++++--- tests/trainer/logging_/test_logger_connector.py | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 468f716d0455b..f364e949db8d4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -412,7 +412,7 @@ def fit( # we reuse fit for other functions. When already set, it shouldn't be modified. 
if not self.state.running: self.state = TrainerState.FITTING - if self._running_stage is None: + if self._running_stage is None or self._running_stage == RunningStage.TUNING: self.training = True # set local properties on the model @@ -625,8 +625,13 @@ def run_train(self) -> None: if not self.interrupted: self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() - self.train_loop.on_train_end() + self.accelerator.on_train_end() + self._running_stage = None except: + # give accelerators a chance to finish + self.accelerator.on_train_end() + # reset bookkeeping + self._running_stage = None print_exc() raise diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 923821a5e50e4..42105a69596bd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -171,8 +171,10 @@ def train_dataloader(self): sampler=None, ) - def training_step_end(self, *_): + def training_step_end(self, training_step_output): self.train_results = deepcopy(self.trainer.logger_connector.cached_results) + # must return + return training_step_output model = TestModel() model.training_epoch_end = None From 834aa53b87bfddfd5478d7a45cd2f16dcd6bcfe1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 6 Apr 2021 23:59:02 -0700 Subject: [PATCH 30/40] remove test related env flags --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 +-- tests/core/test_metric_result_integration.py | 3 --- tests/core/test_results.py | 4 +--- tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 +-- 6 files changed, 5 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 985b849d716fe..126afc9be6040 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,6 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -import numpy log = logging.getLogger(__name__) @@ -79,7 +78,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" + # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..9586344d8c0d9 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,12 +26,11 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import 
BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group @@ -52,7 +51,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index ca9c393692be2..6bad31634ce83 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 46be491315c502c21051c646489ee50374294605 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 7 Apr 2021 20:06:42 -0700 Subject: [PATCH 31/40] comments --- pytorch_lightning/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 582da43a3655f..d104dfd17feb1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -420,7 +420,7 @@ def fit( # we reuse fit for other functions. When already set, it shouldn't be modified. 
if not self.state.running: self.state = TrainerState.FITTING - if self._running_stage is None or self._running_stage == RunningStage.TUNING: + if self._running_stage is None or self.tuning: self.training = True # set local properties on the model @@ -643,7 +643,6 @@ def run_train(self) -> None: # reset bookkeeping self._running_stage = None print_exc() - raise def run_evaluation(self, on_epoch=False): if not (self.evaluating or self.sanity_checking): From 0aff9a3a0a2507710849e56cce5b042525d7d2e6 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 7 Apr 2021 20:14:15 -0700 Subject: [PATCH 32/40] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d2629d3928f0..764e6e714b33d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -220,6 +220,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730)) +- Fixed bug for trainer error handling which would cause hang for distributed training ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864)) ## [1.2.6] - 2021-03-30 From 7ddb1965b9f9722120953ba5d4d020169bae8ee0 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 7 Apr 2021 22:00:44 -0700 Subject: [PATCH 33/40] fix --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d104dfd17feb1..c5dd7f00428f8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -642,7 +642,7 @@ def run_train(self) -> None: self.accelerator.on_train_end() # reset bookkeeping self._running_stage = None - print_exc() + raise def run_evaluation(self, on_epoch=False): if not (self.evaluating or self.sanity_checking): From ba50f5d9cd3bdf62d17b42d7b33b0fc5ac82a4ce Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 8 Apr 2021 00:02:30 -0700 Subject: [PATCH 34/40] formatting issue --- pytorch_lightning/trainer/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c5dd7f00428f8..5d6bd00bb1df4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -16,7 +16,6 @@ import warnings from itertools import count from pathlib import Path -from traceback import print_exc from typing import Any, Dict, Iterable, List, Optional, Union import torch @@ -53,7 +52,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -637,7 +636,7 @@ def run_train(self) -> None: self.on_keyboard_interrupt() self.accelerator.on_train_end() self._running_stage = None - except: + except BaseException: # give accelerators a chance to finish self.accelerator.on_train_end() # reset bookkeeping From 2644953d686b6e9122ba9bcd1a5841a3b48ed24e Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: 
Fri, 9 Apr 2021 16:08:40 -0700 Subject: [PATCH 35/40] modify trainer doc --- docs/source/common/trainer.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 96e19a7be4694..0a93baeee98b5 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -144,8 +144,8 @@ So you can run it like so: .. note:: If you want to stop a training run early, you can press "Ctrl + C" on your keyboard. The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown, including - running callbacks such as ``on_train_end``. The trainer object will also set an attribute - ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute + running accelerator callback ``on_train_end`` to clean up memory. The trainer object will also set + an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs. ------------ From b8f877a75ebe0d37dbeb0b27711557d5cd9867cd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Apr 2021 22:06:58 +0200 Subject: [PATCH 36/40] Update CHANGELOG --- CHANGELOG.md | 7 ++++++- pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 764e6e714b33d..2edd536c80ef9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -105,6 +105,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) +- The trainer no longer tries to save a checkpoint on exception or run callback's `on_train_end` functions ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864)) + + - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) @@ -218,9 +221,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705)) +- Fixed bug for trainer error handling which would cause hang for distributed training ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864)) + + - Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730)) -- Fixed bug for trainer error handling which would cause hang for distributed training ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864)) ## [1.2.6] - 2021-03-30 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5d6bd00bb1df4..c2fcaac0b018b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -634,6 +634,7 @@ def run_train(self) -> None: if not self.interrupted: self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() + # same treatment as below self.accelerator.on_train_end() self._running_stage = None except BaseException: From b11bd931e6dcf79e0504f8f5e09cec0650057279 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Apr 2021 22:41:13 +0200 Subject: [PATCH 37/40] Fix test --- tests/callbacks/test_callback_hook_outputs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index d1bcee43b1f02..7c5a6c03766dc 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -35,8 +35,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert 'x' in outputs def on_train_epoch_end(self, trainer, pl_module, outputs): - d = outputs[0] - assert len(d) == trainer.num_training_batches + assert len(outputs) == trainer.num_training_batches class TestModel(BoringModel): From 2bb6c9924bc01bd58b4dfcb70c8de14b7bb49385 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 13 Apr 2021 19:47:31 -0700 Subject: [PATCH 38/40] fix test for test_training_loop --- tests/trainer/test_training_loop.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py index e8d5fcd4c3b95..24729355f3689 100644 --- a/tests/trainer/test_training_loop.py +++ b/tests/trainer/test_training_loop.py @@ -21,6 +21,7 @@ def test_training_loop_hook_call_order(tmpdir): https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks""" class HookedModel(BoringModel): + def __init__(self): super().__init__() self.called = [] @@ -58,15 +59,15 @@ def on_after_backward(self): super().on_after_backward() def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - on_tpu, - using_native_amp, - using_lbfgs, + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + on_tpu, + using_native_amp, + using_lbfgs, ): super().optimizer_step( epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs @@ -131,16 +132,18 @@ def test_outputs_format(tmpdir): """Tests that outputs objects passed to model hooks and methods are consistent and in the correct format.""" class HookedModel(BoringModel): + def training_step(self, batch, batch_idx): - self.log("foo", "bar") - return 
super().training_step(batch, batch_idx) + output = super().training_step(batch, batch_idx) + self.log("foo", 123) + output["foo"] = 123 + return output @staticmethod def _check_output(output): assert "loss" in output - assert "foo" in output - assert output["foo"] == "bar" + assert output["foo"] == 123 def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): HookedModel._check_output(outputs) From c8c8b08c2340dfff17159dbe94e234486cadbef8 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 13 Apr 2021 19:53:16 -0700 Subject: [PATCH 39/40] formatting --- tests/trainer/test_training_loop.py | 49 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py index 24729355f3689..1349659cc4595 100644 --- a/tests/trainer/test_training_loop.py +++ b/tests/trainer/test_training_loop.py @@ -21,7 +21,6 @@ def test_training_loop_hook_call_order(tmpdir): https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks""" class HookedModel(BoringModel): - def __init__(self): super().__init__() self.called = [] @@ -70,9 +69,18 @@ def optimizer_step( using_lbfgs, ): super().optimizer_step( - epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + on_tpu, + using_native_amp, + using_lbfgs, ) - self.called.append("optimizer_step") # append after as closure calls other methods + self.called.append( + "optimizer_step" + ) # append after as closure calls other methods def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.called.append("on_train_batch_end") @@ -107,23 +115,23 @@ def on_epoch_end(self): trainer.fit(model) expected = [ - 'on_epoch_start', # validation - 'on_epoch_end', - 'on_epoch_start', # training - 'on_train_epoch_start', - 'on_train_batch_start', - 'training_step', - 'on_before_zero_grad', - 'optimizer_zero_grad', - 'backward', - 'on_after_backward', - 'optimizer_step', - 'on_train_batch_end', - 'training_epoch_end', - 'on_train_epoch_end', - 'on_epoch_end', - 'on_epoch_start', # validation - 'on_epoch_end' + "on_epoch_start", # validation + "on_epoch_end", + "on_epoch_start", # training + "on_train_epoch_start", + "on_train_batch_start", + "training_step", + "on_before_zero_grad", + "optimizer_zero_grad", + "backward", + "on_after_backward", + "optimizer_step", + "on_train_batch_end", + "training_epoch_end", + "on_train_epoch_end", + "on_epoch_end", + "on_epoch_start", # validation + "on_epoch_end", ] assert model.called == expected @@ -132,7 +140,6 @@ def test_outputs_format(tmpdir): """Tests that outputs objects passed to model hooks and methods are consistent and in the correct format.""" class HookedModel(BoringModel): - def training_step(self, batch, batch_idx): output = super().training_step(batch, batch_idx) self.log("foo", 123) From d8ca310926dcfefc1b5d3e1ab4d6adf7e161900d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Apr 2021 11:28:32 +0100 Subject: [PATCH 40/40] Fix broken test --- tests/trainer/optimization/test_manual_optimization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 004be23142622..4c8cf99a275f0 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -641,6 +641,8 @@ def 
training_step(self, batch, batch_idx): opt_b.step() opt_b.zero_grad() + return {'loss1': loss_1, 'loss2': loss_2} + def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer assert len(outputs) == 2
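
Note: the sketch below is not part of the patch series. It is a minimal illustration of the cleanup-on-exception pattern that patches 28–36 converge on in `Trainer.run_train` (run `accelerator.on_train_end()` and reset the running stage before re-raising, so distributed workers do not hang). `MiniTrainer`, `train_fn`, and the accelerator object are simplified stand-ins, not the real Lightning API.

```python
class MiniTrainer:
    """Illustrative only: mirrors the error-handling shape added to Trainer.run_train."""

    def __init__(self, accelerator):
        self.accelerator = accelerator
        self._running_stage = "train"
        self.interrupted = False

    def run_train(self, train_fn):
        try:
            train_fn()
        except KeyboardInterrupt:
            # graceful shutdown: mark the run as interrupted and let the
            # accelerator clean up (e.g. tear down process groups); do not re-raise
            self.interrupted = True
            self.accelerator.on_train_end()
            self._running_stage = None
        except BaseException:
            # give the accelerator a chance to finish so distributed workers
            # are not left hanging, reset bookkeeping, then surface the error
            self.accelerator.on_train_end()
            self._running_stage = None
            raise
```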