From 89f284d6fbafcf6aadac7abf1af08ee3fba39865 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 23 Mar 2021 12:06:24 -0700 Subject: [PATCH 01/52] Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 4 +++- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 ++- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ea1efd6e15873..1383964cbc789 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9586344d8c0d9..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,11 +26,12 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group @@ -51,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 536c1323b0e6715fb5919196ea48b0fcddddcd66 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 01:17:20 -0700 Subject: [PATCH 02/52] checkpoint consolidation --- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. + """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:37:52 -0700 Subject: [PATCH 03/52] Update ddp_spawn.py --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From bf70e431b3ce4893de804e0f3b5d59e79346d6d7 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:41:33 -0700 Subject: [PATCH 04/52] Update test_metric_result_integration.py --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From ea749068785bbad689a12066544893b1605f20c5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:42:16 -0700 Subject: [PATCH 05/52] Update test_results.py --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From a9aae99f6ed6e9388ecf1d8a7bd79966176a65af Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:04 -0700 Subject: [PATCH 06/52] Update utils.py --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): From 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:43 -0700 Subject: [PATCH 07/52] Update utils.py --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 0d23d75bc91e4e0b7805712e394cb093fac22841 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:44:18 -0700 Subject: [PATCH 08/52] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From ca6f98ba8ff835368ae3ef91e435e4d4f458c45b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:51:55 -0700 Subject: [PATCH 09/52] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:57:31 -0700 Subject: [PATCH 10/52] Update test_results.py --- tests/core/test_results.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..334062ae994a2 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,6 +30,7 @@ def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From 7635b4f47bcef43b9bbe677ad96a3bad135246a5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:10 -0700 Subject: [PATCH 11/52] Revert "Update test_results.py" This reverts commit 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1. --- tests/core/test_results.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 334062ae994a2..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,7 +30,6 @@ def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From d64f90cbc748de193a02237acd6ac686750b82d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:20 -0700 Subject: [PATCH 12/52] Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" This reverts commit c5053da789f9d04d2c967a65adf4fb026dc134b8, reversing changes made to 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From dcdcd29731061c919b15ab0b56669259817a81c4 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:36 -0700 Subject: [PATCH 13/52] Revert "Update test_all_gather_grad.py" This reverts commit 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 8651d54d79396eaaba16d7eb1e769a1e91d5702e Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:40 -0700 Subject: [PATCH 14/52] Revert "Update utils.py" This reverts commit 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c. --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 15f4b9e59bec52b07dddb55eeda4d9a68b8bd6d2 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:45 -0700 Subject: [PATCH 15/52] Revert "Update utils.py" This reverts commit a9aae99f6ed6e9388ecf1d8a7bd79966176a65af. --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): From 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:48 -0700 Subject: [PATCH 16/52] Revert "Update test_results.py" This reverts commit ea749068785bbad689a12066544893b1605f20c5. --- tests/core/test_results.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From 6c095b2370a2afe9d24918a5798ce1ebffed7e0d Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:52 -0700 Subject: [PATCH 17/52] Revert "Update test_metric_result_integration.py" This reverts commit bf70e431b3ce4893de804e0f3b5d59e79346d6d7. --- tests/core/test_metric_result_integration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From 8222dc98ead37d961a52b7366070aa10f66d92d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:55 -0700 Subject: [PATCH 18/52] Revert "Update ddp_spawn.py" This reverts commit f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b. --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..941025b36c0ac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:17:01 -0700 Subject: [PATCH 19/52] Revert "checkpoint consolidation" This reverts commit 536c1323b0e6715fb5919196ea48b0fcddddcd66. --- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 7a369f47e1a94d701fce48c994cc3f2da266dad0 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:19:37 -0700 Subject: [PATCH 20/52] Revert "Revert "checkpoint consolidation"" This reverts commit 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac. --- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. + """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From b4a0b9e9e1e0a08e50979facc2f0fc74187de2ee Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:32:43 -0700 Subject: [PATCH 21/52] Revert "Revert "Revert "checkpoint consolidation""" This reverts commit 7a369f47e1a94d701fce48c994cc3f2da266dad0. --- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 0ce7e056ac47436bc727f91f8eed335fc736696c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:44 -0700 Subject: [PATCH 22/52] Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc98ead37d961a52b7366070aa10f66d92d1. --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From fe9736d94bfcac3e084eec2d63e62351d3618175 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:49 -0700 Subject: [PATCH 23/52] Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2370a2afe9d24918a5798ce1ebffed7e0d. --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From c314ef6d30373c2c94fdeceef6ee7b9d961a48c9 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:56 -0700 Subject: [PATCH 24/52] Revert "Revert "Update test_results.py"" This reverts commit 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b. --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From c3feda03d7fbb25dcf1917e718209f98d0503327 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:05 -0700 Subject: [PATCH 25/52] Revert "Revert "Update utils.py"" This reverts commit 8651d54d79396eaaba16d7eb1e769a1e91d5702e. --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From c759477a0a9462f812a880a8cee7c09b3f432520 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:13 -0700 Subject: [PATCH 26/52] Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29731061c919b15ab0b56669259817a81c4. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 4e67db2a1fed55e6d6e3fa09766aaa55c7995ca1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 10:57:58 -0700 Subject: [PATCH 27/52] modify distributed environment to make test pass --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 ++- tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 3 +++ tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 2 +- 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..0b4b7680076a3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything +import numpy log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From fffecb8ff05140ef33a4b833a3f0d44b53f40097 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 14 Apr 2021 17:44:43 -0700 Subject: [PATCH 28/52] rfc --- .../plugins/training_type/ddp.py | 71 +++++++++-- .../plugins/training_type/ddp2.py | 5 +- .../plugins/training_type/ddp_spawn.py | 7 +- pytorch_lightning/plugins/training_type/dp.py | 4 + .../plugins/training_type/horovod.py | 4 + .../plugins/training_type/parallel.py | 16 ++- .../plugins/training_type/rpc.py | 2 +- .../plugins/training_type/tpu_spawn.py | 4 + .../connectors/accelerator_connector.py | 113 ++++++++++-------- tests/plugins/test_cluster_integration.py | 17 +-- 10 files changed, 168 insertions(+), 75 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7e9624d9a0122..7e37240ce7483 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -72,22 +72,54 @@ def __init__( ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] - self.num_nodes = num_nodes + self._num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.dist = LightningDistributed() self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None - self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper + # world ranks is related to num_nodes, cluster_environment and parallel_devices + # when resetting these parameters, need to reset world ranks self.set_world_ranks() @property def root_device(self): return self.parallel_devices[self.local_rank] + @property + def num_nodes(self): + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, x: int): + self._num_nodes = x + self.set_world_ranks() + + @property + def parallel_devices(self): + return self._parallel_devices + + @parallel_devices.setter + def parallel_devices(self, parallel_devices: List[torch.device]): + self._parallel_devices = parallel_devices + self.set_world_ranks() + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def cluster_environment(self): + return self._cluster_environment + + @cluster_environment.setter + def cluster_environment(self, cluster_environment: ClusterEnvironment): + self._cluster_environment = cluster_environment + self.set_world_ranks() + @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @@ -99,7 +131,7 @@ def _is_single_process_single_device(self) -> bool: def setup_environment(self): # start the other scripts - if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": + if (not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1"): self._call_children_scripts() # set the task idx @@ -159,7 +191,7 @@ def _call_children_scripts(self): env_copy["LOCAL_RANK"] = f"{local_rank}" # remove env var if global seed not set - if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + if (os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy): del env_copy["PL_GLOBAL_SEED"] # start process @@ -169,7 +201,10 @@ def _call_children_scripts(self): if HydraConfig.initialized(): cwd = get_original_cwd() os_cwd = f'"{os.getcwd()}"' - command += [f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}'] + command += [ + f"hydra.run.dir={os_cwd}", + f"hydra.job.name=train_ddp_process_{local_rank}", + ] proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) @@ -226,8 +261,9 @@ def pre_configure_ddp(self): # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization - if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( - "find_unused_parameters", False + if ( + _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization + and not self._ddp_kwargs.get("find_unused_parameters", False) ): rank_zero_warn( "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " @@ -261,8 +297,8 @@ def determine_ddp_device_ids(self): return [self.root_device.index] def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None: - global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() - world_size = world_size if world_size is not None else self.cluster_environment.world_size() + global_rank = (global_rank if global_rank is not None else self.cluster_environment.global_rank()) + world_size = (world_size if world_size is not None else self.cluster_environment.world_size()) os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) if not torch.distributed.is_initialized(): @@ -291,9 +327,15 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward( + self, + closure_loss: torch.Tensor, + should_accumulate: bool, + optimizer: Optimizer, + opt_idx: int, + ): """Run before precision plugin executes backward""" - if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: + if (not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync): prepare_for_backward(self.model, closure_loss) def model_to_device(self): @@ -301,7 +343,12 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + def reduce( + self, + tensor, + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ): """ Reduces a tensor from several distributed processes to one aggregated tensor. diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index d7c3d84184926..66c15915076b3 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -72,5 +72,6 @@ def _is_single_process_single_device(self) -> bool: return False def set_world_ranks(self): - self.cluster_environment.set_global_rank(self.node_rank) - self.cluster_environment.set_world_size(self.num_nodes) + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank) + self.cluster_environment.set_world_size(self.num_nodes) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 719612cd908c8..5ef436edb26bf 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -16,6 +16,7 @@ import re from typing import Any, Dict, List, Optional, Union +import numpy import torch import torch.distributed as torch_distrib import torch.multiprocessing as mp @@ -33,7 +34,6 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -import numpy if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -61,7 +61,6 @@ def __init__( self.sync_batchnorm = sync_batchnorm self._ddp_kwargs = kwargs self.dist = LightningDistributed() - self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook @@ -82,6 +81,10 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__ = state + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + @property def root_device(self): return self.parallel_devices[self.local_rank] diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 131a134ca724d..06b2aa64268ed 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -27,6 +27,10 @@ class DataParallelPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]]): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + @property + def is_cluster_environment_resettable(self): + return False + @property def global_rank(self) -> int: return 0 diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index cf2ada0e6d9a7..f4b8434082e53 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -33,6 +33,10 @@ def __init__(self, parallel_devices: Optional[List[torch.device]] = None): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) rank_zero_only.rank = self.global_rank + @property + def is_cluster_environment_resettable(self): + return False + @property def global_rank(self) -> int: return hvd.rank() diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 023bdcd0172ff..da50079800442 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -34,14 +34,26 @@ def __init__( cluster_environment: Optional[ClusterEnvironment] = None, ): super().__init__() - self.parallel_devices = parallel_devices - self.cluster_environment = cluster_environment + self._parallel_devices = parallel_devices + self._cluster_environment = cluster_environment @property @abstractmethod def root_device(self): raise NotImplementedError + @property + def parallel_devices(self): + return self._parallel_devices + + @property + def cluster_environment(self): + return self._cluster_environment + + @property + def is_cluster_environment_resettable(self): + return True + @property def on_gpu(self): return self.root_device.type == "cuda" and torch.cuda.is_available() diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index 3e0f57daef001..c2c2f3c257304 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -42,7 +42,7 @@ def __init__( self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, parallel_devices: Optional[List[torch.device]] = None, - num_nodes: Optional[int] = None, + num_nodes: int = 1, cluster_environment: Optional[ClusterEnvironment] = None, sync_batchnorm: Optional[bool] = None, **kwargs diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7132432976491..ab0f84401fd93 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -54,6 +54,10 @@ def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[ self.tpu_global_core_rank = 0 self.start_method = None + @property + def is_cluster_environment_resettable(self): + return False + @property def global_rank(self) -> int: return self.tpu_local_core_rank diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 1f086bbee8ca3..87596ba2b317a 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -148,7 +148,8 @@ def __init__( self.replace_sampler_ddp = replace_sampler_ddp def handle_given_plugins( - self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]] + self, + plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]], ): plugins = plugins if plugins is not None else [] @@ -175,16 +176,16 @@ def handle_given_plugins( else: raise MisconfigurationException( - 'You can only specify one precision and one training type plugin.' - f' Found more than 1 training type plugin: {type(plug).__name__}' + "You can only specify one precision and one training type plugin." + f" Found more than 1 training type plugin: {type(plug).__name__}" ) elif isinstance(plug, PrecisionPlugin): if precision is None: precision = plug else: raise MisconfigurationException( - 'You can only specify one precision and one training type plugin.' - f' Found more than 1 precision plugin: {type(plug).__name__}' + "You can only specify one precision and one training type plugin." + f" Found more than 1 precision plugin: {type(plug).__name__}" ) elif isinstance(plug, ClusterEnvironment): @@ -192,16 +193,16 @@ def handle_given_plugins( cluster_environment = plug else: raise MisconfigurationException( - 'You can only specify one cluster environment. Found more than 1 cluster environment plugin' + "You can only specify one cluster environment. Found more than 1 cluster environment plugin" ) else: raise MisconfigurationException( - f'Found invalid type for plugin {plug}. Expected a precision or training type plugin.' + f"Found invalid type for plugin {plug}. Expected a precision or training type plugin." ) self._training_type_plugin = training_type self._precision_plugin = precision - self._cluster_environment = cluster_environment or self.select_cluster_environment() + self._cluster_environment = (cluster_environment or self.select_cluster_environment()) @property def precision_plugin(self) -> PrecisionPlugin: @@ -249,8 +250,12 @@ def use_dp(self) -> bool: @property def use_ddp(self) -> bool: return self._distrib_type in ( - DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, + DistributedType.DEEPSPEED, + DistributedType.TPU_SPAWN, ) @property @@ -269,7 +274,7 @@ def use_deepspeed(self) -> bool: def is_distributed(self) -> bool: # Used for custom plugins. # Custom plugins should implement is_distributed property. - if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: + if hasattr(self.training_type_plugin, "is_distributed") and not self.on_tpu: return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod if self.on_tpu: @@ -298,7 +303,7 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: @property def root_gpu(self) -> Optional[int]: - return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None + return (self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None) @property def is_using_torchelastic(self) -> bool: @@ -336,8 +341,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for native AMP on CPU, but AMP is only available on GPU." ) elif not _NATIVE_AMP_AVAILABLE: - msg = "You have asked for native AMP but your PyTorch version does not support it." \ - " Consider upgrading with `pip install torch>=1.6`." + msg = ( + "You have asked for native AMP but your PyTorch version does not support it." + " Consider upgrading with `pip install torch>=1.6`." + ) if _APEX_AVAILABLE: self.amp_type = AMPType.APEX msg += " We will attempt to use NVIDIA Apex for this session." @@ -346,7 +353,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if isinstance( + self.training_type_plugin, + (DDPShardedPlugin, DDPSpawnShardedPlugin), + ): return ShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() @@ -378,18 +388,18 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = DeepSpeedPlugin( num_nodes=self.num_nodes, cluster_environment=self.select_cluster_environment(), - parallel_devices=self.parallel_devices + parallel_devices=self.parallel_devices, ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks - use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic() + use_torchelastic_ddp = (self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()) use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN use_ddp_cpu_spawn = self.use_ddp and self.on_cpu - use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN - use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic() + use_tpu_spawn = (self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN) + use_ddp_cpu_torch_elastic = (use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()) use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED - use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN + use_ddp_sharded_spawn = (self._distrib_type == DistributedType.DDP_SHARDED_SPAWN) # TODO: decouple from TE # ddp script mode uses the same flags as TE @@ -402,7 +412,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: ddp_plugin_cls = DDPShardedPlugin elif use_ddp_sharded_spawn: ddp_plugin_cls = DDPSpawnShardedPlugin - elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: + elif (use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp): ddp_plugin_cls = DDPPlugin elif use_ddp_spawn or use_ddp_cpu_spawn: ddp_plugin_cls = DDPSpawnPlugin @@ -428,20 +438,21 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: # necessary for when the user has passed in a plugin - if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): + if hasattr(training_type, "parallel_devices") and not getattr(training_type, "parallel_devices"): training_type.parallel_devices = self.parallel_devices - if hasattr(training_type, 'num_processes'): - training_type.num_processes = len(self.parallel_devices) - if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: + if ( + hasattr(training_type, "cluster_environment") and getattr(training_type, "cluster_environment") is None + and training_type.is_cluster_environment_resettable + ): training_type.cluster_environment = self.select_cluster_environment() - if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: + if hasattr(training_type, "num_nodes"): + # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes - # Automatically set sync_batchnorm if None. - # Useful for custom plugins. - if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: + if hasattr(training_type, "sync_batchnorm"): + # Set sync_batchnorm for training_type from trainer setting. training_type.sync_batchnorm = self.sync_batchnorm return training_type @@ -449,11 +460,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra def select_accelerator(self) -> Accelerator: if isinstance(self.distributed_backend, Accelerator): # custom accelerator from user - if self._precision_plugin is not None or self._training_type_plugin is not None: + if (self._precision_plugin is not None or self._training_type_plugin is not None): # plugins also specified by user rank_zero_warn( - 'Specified `Precision` and `TrainingType` plugins will be ignored,' - ' since an `Accelerator` instance was provided.' + "Specified `Precision` and `TrainingType` plugins will be ignored," + " since an `Accelerator` instance was provided." ) return self.distributed_backend @@ -495,7 +506,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType.DDP elif self.num_gpus > 1: rank_zero_warn( - 'You requested multiple GPUs but did not specify a backend, e.g.' + "You requested multiple GPUs but did not specify a backend, e.g." ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' ) self.distributed_backend = "ddp_spawn" @@ -505,14 +516,14 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType.DDP if self.num_gpus > 0: rank_zero_warn( - 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' + "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs." ) self.parallel_device_ids = None if self.num_processes is None: # define the max CPU available self.num_processes = os.cpu_count() # special case with TPUs - elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: + elif self.distributed_backend == "tpu" or self.tpu_cores is not None: self._device_type = DeviceType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = DistributedType.TPU_SPAWN @@ -520,30 +531,35 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType(self.distributed_backend) # unless you request explicitly for CPU and some GPU are available use them - _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend + _on_cpu = self.distributed_backend and "cpu" in self.distributed_backend if self.num_gpus > 0 and not _on_cpu: self._device_type = DeviceType.GPU - _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + _gpu_distrib_types = ( + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.DDP2, + ) # DP and DDP2 cannot run without GPU - if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu: + if (self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu): rank_zero_warn( - 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' + "You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`." ) # todo: in some cases it yield in comarison None and int if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): self._distrib_type = DistributedType.DDP else: - rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') + rank_zero_warn("You are running on single node with no parallelization, so distributed has no effect.") self._distrib_type = None # finished configuring self._distrib_type, check ipython environment self.check_interactive_compatibility() # for DDP overwrite nb processes by requested GPUs - if ( - self._device_type == DeviceType.GPU - and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + if self._device_type == DeviceType.GPU and self._distrib_type in ( + DistributedType.DDP, + DistributedType.DDP_SPAWN, ): self.num_processes = self.num_gpus @@ -558,13 +574,13 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if self.num_nodes > 1 and not using_valid_distributed: # throw error to force user to choose a supported distributed type such as ddp or ddp2 raise MisconfigurationException( - 'Your chosen distributed type does not support num_nodes > 1. ' - 'Please set accelerator=ddp or accelerator=ddp2.' + "Your chosen distributed type does not support num_nodes > 1. " + "Please set accelerator=ddp or accelerator=ddp2." ) - rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}') + rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}") num_cores = self.tpu_cores if self.tpu_cores is not None else 0 - rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') + rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores") if torch.cuda.is_available() and self._device_type != DeviceType.GPU: rank_zero_warn( @@ -590,7 +606,8 @@ def check_interactive_compatibility(self): is not compatible with an interactive environment """ from pytorch_lightning.utilities import _IS_INTERACTIVE - if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible(): + + if (_IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible()): raise MisconfigurationException( f"Selected distributed backend {self._distrib_type} is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible backends:" diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py index 032276dd674d0..48367fcb8078c 100644 --- a/tests/plugins/test_cluster_integration.py +++ b/tests/plugins/test_cluster_integration.py @@ -47,13 +47,14 @@ def environment_combinations(): @pytest.mark.parametrize( - "plugin_cls", [ + "plugin_cls", + [ DDPPlugin, DDPShardedPlugin, DDP2Plugin, pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)), - ] + ], ) def test_ranks_available_manual_plugin_selection(plugin_cls): """ Test that the rank information is readily available after Trainer initialization. """ @@ -64,12 +65,11 @@ def test_ranks_available_manual_plugin_selection(plugin_cls): expected.update(global_rank=expected["node_rank"], world_size=num_nodes) with mock.patch.dict(os.environ, variables): - plugin = plugin_cls( - parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], + plugin = plugin_cls(parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)]) + trainer = Trainer( + plugins=[cluster, plugin], num_nodes=num_nodes, - cluster_environment=cluster, ) - trainer = Trainer(plugins=[plugin]) assert rank_zero_only.rank == expected["global_rank"] assert trainer.global_rank == expected["global_rank"] assert trainer.local_rank == expected["local_rank"] @@ -78,13 +78,14 @@ def test_ranks_available_manual_plugin_selection(plugin_cls): @pytest.mark.parametrize( - "trainer_kwargs", [ + "trainer_kwargs", + [ dict(accelerator="ddp", gpus=[1, 2]), dict(accelerator="ddp_sharded", gpus=[1, 2]), dict(accelerator="ddp2", gpus=[1, 2]), dict(accelerator="ddp_cpu", num_processes=2), dict(accelerator="ddp_spawn", gpus=[1, 2]), - ] + ], ) @mock.patch("torch.cuda.is_available", return_value=True) @mock.patch("torch.cuda.device_count", return_value=4) From 089e566b1447548685fde8c24a8883d657202633 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 14 Apr 2021 18:00:04 -0700 Subject: [PATCH 29/52] rebase --- tests/core/test_metric_result_integration.py | 3 --- tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 +-- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index ca9c393692be2..6bad31634ce83 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From bb8ed7796d60f4afa2c60887fcb2db4ad46e6d2d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 14 Apr 2021 18:06:15 -0700 Subject: [PATCH 30/52] formatting --- .../plugins/training_type/ddp.py | 37 ++----- .../connectors/accelerator_connector.py | 97 ++++++++----------- 2 files changed, 50 insertions(+), 84 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7e37240ce7483..0933a848fe6b0 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -131,7 +131,7 @@ def _is_single_process_single_device(self) -> bool: def setup_environment(self): # start the other scripts - if (not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1"): + if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() # set the task idx @@ -191,7 +191,7 @@ def _call_children_scripts(self): env_copy["LOCAL_RANK"] = f"{local_rank}" # remove env var if global seed not set - if (os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy): + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: del env_copy["PL_GLOBAL_SEED"] # start process @@ -201,10 +201,7 @@ def _call_children_scripts(self): if HydraConfig.initialized(): cwd = get_original_cwd() os_cwd = f'"{os.getcwd()}"' - command += [ - f"hydra.run.dir={os_cwd}", - f"hydra.job.name=train_ddp_process_{local_rank}", - ] + command += [f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}'] proc = subprocess.Popen(command, env=env_copy, cwd=cwd) self.interactive_ddp_procs.append(proc) @@ -261,9 +258,8 @@ def pre_configure_ddp(self): # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization - if ( - _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization - and not self._ddp_kwargs.get("find_unused_parameters", False) + if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( + "find_unused_parameters", False ): rank_zero_warn( "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " @@ -297,8 +293,8 @@ def determine_ddp_device_ids(self): return [self.root_device.index] def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None: - global_rank = (global_rank if global_rank is not None else self.cluster_environment.global_rank()) - world_size = (world_size if world_size is not None else self.cluster_environment.world_size()) + global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() + world_size = world_size if world_size is not None else self.cluster_environment.world_size() os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) if not torch.distributed.is_initialized(): @@ -327,15 +323,9 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) - def pre_backward( - self, - closure_loss: torch.Tensor, - should_accumulate: bool, - optimizer: Optimizer, - opt_idx: int, - ): + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" - if (not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync): + if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) def model_to_device(self): @@ -343,21 +333,14 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def reduce( - self, - tensor, - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ): + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ Reduces a tensor from several distributed processes to one aggregated tensor. - Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be a string 'sum' to calculate the sum during reduction. - Return: reduced value, except when the input was not a tensor the output remains is unchanged """ diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 87596ba2b317a..bb7b090bb2f5e 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -148,8 +148,7 @@ def __init__( self.replace_sampler_ddp = replace_sampler_ddp def handle_given_plugins( - self, - plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]], + self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]] ): plugins = plugins if plugins is not None else [] @@ -176,16 +175,16 @@ def handle_given_plugins( else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." - f" Found more than 1 training type plugin: {type(plug).__name__}" + 'You can only specify one precision and one training type plugin.' + f' Found more than 1 training type plugin: {type(plug).__name__}' ) elif isinstance(plug, PrecisionPlugin): if precision is None: precision = plug else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." - f" Found more than 1 precision plugin: {type(plug).__name__}" + 'You can only specify one precision and one training type plugin.' + f' Found more than 1 precision plugin: {type(plug).__name__}' ) elif isinstance(plug, ClusterEnvironment): @@ -193,16 +192,16 @@ def handle_given_plugins( cluster_environment = plug else: raise MisconfigurationException( - "You can only specify one cluster environment. Found more than 1 cluster environment plugin" + 'You can only specify one cluster environment. Found more than 1 cluster environment plugin' ) else: raise MisconfigurationException( - f"Found invalid type for plugin {plug}. Expected a precision or training type plugin." + f'Found invalid type for plugin {plug}. Expected a precision or training type plugin.' ) self._training_type_plugin = training_type self._precision_plugin = precision - self._cluster_environment = (cluster_environment or self.select_cluster_environment()) + self._cluster_environment = cluster_environment or self.select_cluster_environment() @property def precision_plugin(self) -> PrecisionPlugin: @@ -250,12 +249,8 @@ def use_dp(self) -> bool: @property def use_ddp(self) -> bool: return self._distrib_type in ( - DistributedType.DDP, - DistributedType.DDP_SPAWN, - DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, - DistributedType.DEEPSPEED, - DistributedType.TPU_SPAWN, + DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN ) @property @@ -274,7 +269,7 @@ def use_deepspeed(self) -> bool: def is_distributed(self) -> bool: # Used for custom plugins. # Custom plugins should implement is_distributed property. - if hasattr(self.training_type_plugin, "is_distributed") and not self.on_tpu: + if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod if self.on_tpu: @@ -303,14 +298,13 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: @property def root_gpu(self) -> Optional[int]: - return (self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None) + return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None @property def is_using_torchelastic(self) -> bool: """ .. deprecated:: v1.3 Will be removed in v1.5.0. - Returns: ``True`` if the current process was launched using the torchelastic command. """ @@ -341,10 +335,8 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for native AMP on CPU, but AMP is only available on GPU." ) elif not _NATIVE_AMP_AVAILABLE: - msg = ( - "You have asked for native AMP but your PyTorch version does not support it." - " Consider upgrading with `pip install torch>=1.6`." - ) + msg = "You have asked for native AMP but your PyTorch version does not support it." \ + " Consider upgrading with `pip install torch>=1.6`." if _APEX_AVAILABLE: self.amp_type = AMPType.APEX msg += " We will attempt to use NVIDIA Apex for this session." @@ -353,10 +345,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") - if isinstance( - self.training_type_plugin, - (DDPShardedPlugin, DDPSpawnShardedPlugin), - ): + if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): return ShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() @@ -388,18 +377,18 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = DeepSpeedPlugin( num_nodes=self.num_nodes, cluster_environment=self.select_cluster_environment(), - parallel_devices=self.parallel_devices, + parallel_devices=self.parallel_devices ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks - use_torchelastic_ddp = (self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()) + use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic() use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN use_ddp_cpu_spawn = self.use_ddp and self.on_cpu - use_tpu_spawn = (self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN) - use_ddp_cpu_torch_elastic = (use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()) + use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN + use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED - use_ddp_sharded_spawn = (self._distrib_type == DistributedType.DDP_SHARDED_SPAWN) + use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN # TODO: decouple from TE # ddp script mode uses the same flags as TE @@ -412,7 +401,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: ddp_plugin_cls = DDPShardedPlugin elif use_ddp_sharded_spawn: ddp_plugin_cls = DDPSpawnShardedPlugin - elif (use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp): + elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: ddp_plugin_cls = DDPPlugin elif use_ddp_spawn or use_ddp_cpu_spawn: ddp_plugin_cls = DDPSpawnPlugin @@ -460,11 +449,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra def select_accelerator(self) -> Accelerator: if isinstance(self.distributed_backend, Accelerator): # custom accelerator from user - if (self._precision_plugin is not None or self._training_type_plugin is not None): + if self._precision_plugin is not None or self._training_type_plugin is not None: # plugins also specified by user rank_zero_warn( - "Specified `Precision` and `TrainingType` plugins will be ignored," - " since an `Accelerator` instance was provided." + 'Specified `Precision` and `TrainingType` plugins will be ignored,' + ' since an `Accelerator` instance was provided.' ) return self.distributed_backend @@ -506,7 +495,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType.DDP elif self.num_gpus > 1: rank_zero_warn( - "You requested multiple GPUs but did not specify a backend, e.g." + 'You requested multiple GPUs but did not specify a backend, e.g.' ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' ) self.distributed_backend = "ddp_spawn" @@ -516,14 +505,14 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType.DDP if self.num_gpus > 0: rank_zero_warn( - "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs." + 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' ) self.parallel_device_ids = None if self.num_processes is None: # define the max CPU available self.num_processes = os.cpu_count() # special case with TPUs - elif self.distributed_backend == "tpu" or self.tpu_cores is not None: + elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: self._device_type = DeviceType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = DistributedType.TPU_SPAWN @@ -531,35 +520,30 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType(self.distributed_backend) # unless you request explicitly for CPU and some GPU are available use them - _on_cpu = self.distributed_backend and "cpu" in self.distributed_backend + _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend if self.num_gpus > 0 and not _on_cpu: self._device_type = DeviceType.GPU - _gpu_distrib_types = ( - DistributedType.DP, - DistributedType.DDP, - DistributedType.DDP_SPAWN, - DistributedType.DDP2, - ) + _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) # DP and DDP2 cannot run without GPU - if (self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu): + if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu: rank_zero_warn( - "You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`." + 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) # todo: in some cases it yield in comarison None and int if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): self._distrib_type = DistributedType.DDP else: - rank_zero_warn("You are running on single node with no parallelization, so distributed has no effect.") + rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') self._distrib_type = None # finished configuring self._distrib_type, check ipython environment self.check_interactive_compatibility() # for DDP overwrite nb processes by requested GPUs - if self._device_type == DeviceType.GPU and self._distrib_type in ( - DistributedType.DDP, - DistributedType.DDP_SPAWN, + if ( + self._device_type == DeviceType.GPU + and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) ): self.num_processes = self.num_gpus @@ -574,13 +558,13 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if self.num_nodes > 1 and not using_valid_distributed: # throw error to force user to choose a supported distributed type such as ddp or ddp2 raise MisconfigurationException( - "Your chosen distributed type does not support num_nodes > 1. " - "Please set accelerator=ddp or accelerator=ddp2." + 'Your chosen distributed type does not support num_nodes > 1. ' + 'Please set accelerator=ddp or accelerator=ddp2.' ) - rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}") + rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}') num_cores = self.tpu_cores if self.tpu_cores is not None else 0 - rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores") + rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') if torch.cuda.is_available() and self._device_type != DeviceType.GPU: rank_zero_warn( @@ -606,8 +590,7 @@ def check_interactive_compatibility(self): is not compatible with an interactive environment """ from pytorch_lightning.utilities import _IS_INTERACTIVE - - if (_IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible()): + if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible(): raise MisconfigurationException( f"Selected distributed backend {self._distrib_type} is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible backends:" From 6b7fe6f8afd632f3056144ff6c1bb5d2d02ca771 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 14 Apr 2021 18:10:35 -0700 Subject: [PATCH 31/52] more nits --- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - tests/core/test_results.py | 3 --- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 0933a848fe6b0..81a3585ceb4a9 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -94,8 +94,8 @@ def num_nodes(self): return self._num_nodes @num_nodes.setter - def num_nodes(self, x: int): - self._num_nodes = x + def num_nodes(self, num_nodes: int): + self._num_nodes = num_nodes self.set_world_ranks() @property diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 486901e6232be..f3d5fd0b56832 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -16,7 +16,6 @@ import re from typing import Any, Dict, List, Optional, Union -import numpy import torch import torch.distributed as torch_distrib import torch.multiprocessing as mp diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From 90fa8e0266e30a093ba0675dce5f8e3d6089ab7f Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 14 Apr 2021 18:13:01 -0700 Subject: [PATCH 32/52] nit --- tests/core/test_results.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..9586344d8c0d9 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,6 +30,7 @@ def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From ba4f9c4df55c978a48afe131398143f9eeb07369 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 15 Apr 2021 00:01:31 -0700 Subject: [PATCH 33/52] split, setting num_nodes and sync batchnorm only --- .../plugins/training_type/ddp.py | 29 ++++--------------- .../plugins/training_type/ddp_spawn.py | 18 ++++++++---- pytorch_lightning/plugins/training_type/dp.py | 4 --- .../plugins/training_type/horovod.py | 4 --- .../plugins/training_type/parallel.py | 16 ++-------- .../plugins/training_type/tpu_spawn.py | 4 --- .../connectors/accelerator_connector.py | 13 ++++----- tests/plugins/test_cluster_integration.py | 7 +++-- 8 files changed, 31 insertions(+), 64 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 81a3585ceb4a9..460889261f992 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -75,14 +75,15 @@ def __init__( self._num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.dist = LightningDistributed() + self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper - # world ranks is related to num_nodes, cluster_environment and parallel_devices - # when resetting these parameters, need to reset world ranks + # note that world ranks is related to num_nodes, when resetting these parameters, + # need to reset world ranks self.set_world_ranks() @property @@ -98,28 +99,6 @@ def num_nodes(self, num_nodes: int): self._num_nodes = num_nodes self.set_world_ranks() - @property - def parallel_devices(self): - return self._parallel_devices - - @parallel_devices.setter - def parallel_devices(self, parallel_devices: List[torch.device]): - self._parallel_devices = parallel_devices - self.set_world_ranks() - - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - - @property - def cluster_environment(self): - return self._cluster_environment - - @cluster_environment.setter - def cluster_environment(self, cluster_environment: ClusterEnvironment): - self._cluster_environment = cluster_environment - self.set_world_ranks() - @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @@ -336,11 +315,13 @@ def model_to_device(self): def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ Reduces a tensor from several distributed processes to one aggregated tensor. + Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be a string 'sum' to calculate the sum during reduction. + Return: reduced value, except when the input was not a tensor the output remains is unchanged """ diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index f3d5fd0b56832..a1a37f8f28038 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -60,15 +60,27 @@ def __init__( **kwargs: Union[Any, Dict[str, Any]], ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) - self.num_nodes = num_nodes + self._num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self._ddp_kwargs = kwargs self.dist = LightningDistributed() + self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._local_rank = 0 + # note that world ranks is related to num_nodes, when resetting these parameters, + # need to reset world ranks + self.set_world_ranks() + + @property + def num_nodes(self): + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int): + self._num_nodes = num_nodes self.set_world_ranks() @property @@ -84,10 +96,6 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__ = state - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - @property def root_device(self): return self.parallel_devices[self.local_rank] diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 2511891ccaa18..08caa7398ab8c 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -31,10 +31,6 @@ class DataParallelPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]]): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) - @property - def is_cluster_environment_resettable(self): - return False - @property def global_rank(self) -> int: return 0 diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index cddb52229a717..99899aed11753 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -34,10 +34,6 @@ def __init__(self, parallel_devices: Optional[List[torch.device]] = None): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) rank_zero_only.rank = self.global_rank - @property - def is_cluster_environment_resettable(self): - return False - @property def global_rank(self) -> int: return hvd.rank() diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 1010099759fce..696e9695f2200 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -35,26 +35,14 @@ def __init__( cluster_environment: Optional[ClusterEnvironment] = None, ): super().__init__() - self._parallel_devices = parallel_devices - self._cluster_environment = cluster_environment + self.parallel_devices = parallel_devices + self.cluster_environment = cluster_environment @property @abstractmethod def root_device(self): raise NotImplementedError - @property - def parallel_devices(self): - return self._parallel_devices - - @property - def cluster_environment(self): - return self._cluster_environment - - @property - def is_cluster_environment_resettable(self): - return True - @property def on_gpu(self): return self.root_device.type == "cuda" and torch.cuda.is_available() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d102ddd20c76e..902471ea55f51 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -53,10 +53,6 @@ def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[ self.tpu_global_core_rank = 0 self.start_method = None - @property - def is_cluster_environment_resettable(self): - return False - @property def global_rank(self) -> int: return self.tpu_local_core_rank diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index bb7b090bb2f5e..dc7e175d6adca 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -427,20 +427,19 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: # necessary for when the user has passed in a plugin - if hasattr(training_type, "parallel_devices") and not getattr(training_type, "parallel_devices"): + if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): training_type.parallel_devices = self.parallel_devices + if hasattr(training_type, 'num_processes'): + training_type.num_processes = len(self.parallel_devices) - if ( - hasattr(training_type, "cluster_environment") and getattr(training_type, "cluster_environment") is None - and training_type.is_cluster_environment_resettable - ): + if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: training_type.cluster_environment = self.select_cluster_environment() - if hasattr(training_type, "num_nodes"): + if hasattr(training_type, 'num_nodes'): # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes - if hasattr(training_type, "sync_batchnorm"): + if hasattr(training_type, 'sync_batchnorm'): # Set sync_batchnorm for training_type from trainer setting. training_type.sync_batchnorm = self.sync_batchnorm diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py index 48367fcb8078c..8c720e0c4d990 100644 --- a/tests/plugins/test_cluster_integration.py +++ b/tests/plugins/test_cluster_integration.py @@ -65,9 +65,12 @@ def test_ranks_available_manual_plugin_selection(plugin_cls): expected.update(global_rank=expected["node_rank"], world_size=num_nodes) with mock.patch.dict(os.environ, variables): - plugin = plugin_cls(parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)]) + plugin = plugin_cls( + parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], + cluster_environment=cluster, + ) trainer = Trainer( - plugins=[cluster, plugin], + plugins=[plugin], num_nodes=num_nodes, ) assert rank_zero_only.rank == expected["global_rank"] From bdb66ab8921f975c30827a992ab4870c2358c9cf Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 15 Apr 2021 12:14:50 -0700 Subject: [PATCH 34/52] fix test --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index dc7e175d6adca..0ede3a845b1f2 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -435,11 +435,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: training_type.cluster_environment = self.select_cluster_environment() - if hasattr(training_type, 'num_nodes'): + if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') != self.num_nodes: # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes - if hasattr(training_type, 'sync_batchnorm'): + if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') != self.sync_batchnorm: # Set sync_batchnorm for training_type from trainer setting. training_type.sync_batchnorm = self.sync_batchnorm From 552f445c84ad38a87d840c916dabf936ebc37d2d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 15 Apr 2021 16:04:04 -0700 Subject: [PATCH 35/52] add changlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 277fee3463e22..87ece5bb931ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -125,6 +125,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/)) +- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026)) + + ### Deprecated - Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) From 1655f1e80c2a53d6623a285f1336cc9e5848c890 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 15 Apr 2021 17:14:34 -0700 Subject: [PATCH 36/52] retrigger checkes From ad77ad413b33f2443065144eb1345389beb468e0 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 20 Apr 2021 12:55:21 -0700 Subject: [PATCH 37/52] comments --- pytorch_lightning/plugins/training_type/ddp.py | 17 +++++++++-------- pytorch_lightning/plugins/training_type/ddp2.py | 9 +++++---- .../plugins/training_type/ddp_spawn.py | 13 +++++++------ .../trainer/connectors/accelerator_connector.py | 6 +++--- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 489c7140fed09..27cf69a7a44aa 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -82,8 +82,6 @@ def __init__( self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper - # note that world ranks is related to num_nodes, when resetting these parameters, - # need to reset world ranks self.set_world_ranks() @property @@ -91,11 +89,13 @@ def root_device(self): return self.parallel_devices[self.local_rank] @property - def num_nodes(self): + def num_nodes(self) -> int: return self._num_nodes @num_nodes.setter - def num_nodes(self, num_nodes: int): + def num_nodes(self, num_nodes: int) -> None: + # note that world ranks is related to num_nodes, when resetting these parameters, + # need to reset world ranks self._num_nodes = num_nodes self.set_world_ranks() @@ -225,10 +225,11 @@ def _check_can_spawn_children(self): ) def set_world_ranks(self) -> None: - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - rank_zero_only.rank = self.cluster_environment.global_rank() + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index 3c9773b3b5fb1..b6d21904d1933 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -72,7 +72,8 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return False - def set_world_ranks(self): - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank) - self.cluster_environment.set_world_size(self.num_nodes) + def set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank) + self.cluster_environment.set_world_size(self.num_nodes) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a1a37f8f28038..d2edefab65092 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -75,11 +75,11 @@ def __init__( self.set_world_ranks() @property - def num_nodes(self): + def num_nodes(self) -> int: return self._num_nodes @num_nodes.setter - def num_nodes(self, num_nodes: int): + def num_nodes(self, num_nodes: int) -> None: self._num_nodes = num_nodes self.set_world_ranks() @@ -117,10 +117,11 @@ def setup(self, model): def set_world_ranks(self, process_idx: int = 0) -> None: self._local_rank = process_idx - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - rank_zero_only.rank = self.cluster_environment.global_rank() + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() @property def mp_spawn_kwargs(self): diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0ede3a845b1f2..bf382e450b302 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -435,12 +435,12 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: training_type.cluster_environment = self.select_cluster_environment() - if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') != self.num_nodes: + if hasattr(training_type, 'num_nodes'): # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes - if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') != self.sync_batchnorm: - # Set sync_batchnorm for training_type from trainer setting. + if hasattr(training_type, 'sync_batchnorm'): + # set sync_batchnorm for training_type from trainer setting training_type.sync_batchnorm = self.sync_batchnorm return training_type From 77ef90a238692d018eb5d3c692cd61577026745d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 20 Apr 2021 16:26:35 -0700 Subject: [PATCH 38/52] change accelerator_connector training_type_plugin to resolve only once --- CHANGELOG.md | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 +--- .../trainer/connectors/accelerator_connector.py | 8 +++++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62a7b1e566f14..a45f8791058f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,7 +138,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled `lr_scheduler.step()` in manual optimization ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825)) -- Changed warnings and recommendations for dastaloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/)) +- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/)) - Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026)) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4b2c69980ae55..60a986885e974 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -70,8 +70,6 @@ def __init__( self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._local_rank = 0 - # note that world ranks is related to num_nodes, when resetting these parameters, - # need to reset world ranks self.set_world_ranks() @property diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c6dc1a8cb244b..365e0befef9fb 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -123,6 +123,7 @@ def __init__( self.handle_given_plugins(plugins) + self._training_type_plugin_resolved = False self.accelerator = self.select_accelerator() # override dist backend when using tpus @@ -210,6 +211,7 @@ def handle_given_plugins( ) self._training_type_plugin = training_type + print(f"self._training_type_plugin is {self._training_type_plugin}") self._precision_plugin = precision self._cluster_environment = cluster_environment or self.select_cluster_environment() @@ -221,10 +223,14 @@ def precision_plugin(self) -> PrecisionPlugin: @property def training_type_plugin(self) -> TrainingTypePlugin: + if self._training_type_plugin_resolved: + # avoid calling `resolve_training_type_plugin` multiple times + return self._training_type_plugin if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() else: self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) + self._training_type_plugin_resolved = True return self._training_type_plugin @@ -437,7 +443,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: # necessary for when the user has passed in a plugin - if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): + if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None: training_type.parallel_devices = self.parallel_devices if hasattr(training_type, 'num_processes'): training_type.num_processes = len(self.parallel_devices) From 36427ca89dd774a1b8c17eb32628304e25c3a871 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 20 Apr 2021 16:30:00 -0700 Subject: [PATCH 39/52] nits --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 +++- pytorch_lightning/trainer/connectors/accelerator_connector.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 60a986885e974..3627f57a0d54f 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -78,6 +78,8 @@ def num_nodes(self) -> int: @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: + # note that world ranks is related to num_nodes, when resetting these parameters, + # need to reset world ranks self._num_nodes = num_nodes self.set_world_ranks() diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 365e0befef9fb..306fe1ebdd305 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -211,7 +211,6 @@ def handle_given_plugins( ) self._training_type_plugin = training_type - print(f"self._training_type_plugin is {self._training_type_plugin}") self._precision_plugin = precision self._cluster_environment = cluster_environment or self.select_cluster_environment() From 824fb2565ea1e66511f5b3f5604fd176da267d60 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 20 Apr 2021 21:09:14 -0700 Subject: [PATCH 40/52] make num_nodes and sync_batchnorm as optional argument for plugin and use resolve_training_type_plugin to reset from trainer params --- .../plugins/training_type/ddp.py | 24 +++++++++++++++--- .../plugins/training_type/ddp_spawn.py | 25 ++++++++++++++++--- .../plugins/training_type/deepspeed.py | 2 +- .../plugins/training_type/rpc.py | 2 +- .../connectors/accelerator_connector.py | 11 ++------ 5 files changed, 45 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 8578356800518..df37d3aa0e01b 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -62,9 +62,9 @@ class DDPPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]] = None, - num_nodes: int = 1, + num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, - sync_batchnorm: bool = False, + sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -72,8 +72,16 @@ def __init__( ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] - self._num_nodes = num_nodes - self.sync_batchnorm = sync_batchnorm + if num_nodes is not None: + rank_zero_warn( + "`num_nodes` passed in DDPPlugin constructor, but notice that it will be overriden by the trainer setting." + ) + self._num_nodes = num_nodes or 1 + if sync_batchnorm is not None: + rank_zero_warn( + "`sync_batchnorm` passed in DDPPlugin constructor, but notice that it will be overriden by the trainer setting." + ) + self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs @@ -99,6 +107,14 @@ def num_nodes(self, num_nodes: int) -> None: self._num_nodes = num_nodes self.set_world_ranks() + @property + def sync_batchnorm(self) -> bool: + return self._sync_batchnorm + + @sync_batchnorm.setter + def sync_batchnorm(self, sync_batchnorm: bool) -> None: + self._sync_batchnorm = sync_batchnorm + @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3627f57a0d54f..0a9a226ed696e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -51,17 +51,25 @@ class DDPSpawnPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]] = None, - num_nodes: int = 1, + num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, - sync_batchnorm: bool = False, + sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Any, ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) - self._num_nodes = num_nodes - self.sync_batchnorm = sync_batchnorm + if num_nodes is not None: + rank_zero_warn( + "`num_nodes` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by the trainer setting." + ) + self._num_nodes = num_nodes or 1 + if sync_batchnorm is not None: + rank_zero_warn( + "`sync_batchnorm` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by the trainer setting." + ) + self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs self.dist = LightningDistributed() self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 @@ -83,6 +91,14 @@ def num_nodes(self, num_nodes: int) -> None: self._num_nodes = num_nodes self.set_world_ranks() + @property + def sync_batchnorm(self) -> bool: + return self._sync_batchnorm + + @sync_batchnorm.setter + def sync_batchnorm(self, sync_batchnorm: bool) -> None: + self._sync_batchnorm = sync_batchnorm + @property def local_rank(self) -> int: return self._local_rank @@ -177,6 +193,7 @@ def new_process(self, process_idx, trainer, mp_queue): # move the model to the correct device self.model_to_device() + assert self.sync_batchnorm is not None if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 34a9f504082e1..ab55ddf17ea76 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -90,7 +90,7 @@ def __init__( zero_allow_untested_optimizer: bool = True, config: Optional[Union[Path, str, dict]] = None, logging_level: int = logging.WARN, - num_nodes: int = 1, + num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index c2c2f3c257304..3e0f57daef001 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -42,7 +42,7 @@ def __init__( self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, parallel_devices: Optional[List[torch.device]] = None, - num_nodes: int = 1, + num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, sync_batchnorm: Optional[bool] = None, **kwargs diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 306fe1ebdd305..7f69e62b38634 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -227,8 +227,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: return self._training_type_plugin if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() - else: - self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) + self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) self._training_type_plugin_resolved = True return self._training_type_plugin @@ -384,15 +383,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: plugin = DDP2Plugin( parallel_devices=self.parallel_devices, - num_nodes=self.num_nodes, cluster_environment=self.cluster_environment, - sync_batchnorm=self.sync_batchnorm, ) elif self.use_ddp and self.use_deepspeed: plugin = DeepSpeedPlugin( - num_nodes=self.num_nodes, - cluster_environment=self.select_cluster_environment(), - parallel_devices=self.parallel_devices + cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks @@ -425,9 +420,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = ddp_plugin_cls( parallel_devices=self.parallel_devices, - num_nodes=self.num_nodes, cluster_environment=self.cluster_environment, - sync_batchnorm=self.sync_batchnorm, ) elif self.use_dp: plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) From 66fab62c048c53efa9a6e69474612b8e0252b1ec Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 20 Apr 2021 22:09:01 -0700 Subject: [PATCH 41/52] format --- pytorch_lightning/plugins/training_type/ddp.py | 9 +++++---- pytorch_lightning/plugins/training_type/ddp_spawn.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index df37d3aa0e01b..7cc42f0d4ac36 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -74,12 +74,14 @@ def __init__( self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_warn( - "`num_nodes` passed in DDPPlugin constructor, but notice that it will be overriden by the trainer setting." + "`num_nodes` passed in DDPPlugin constructor, but notice that it will be overriden by " + "the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_warn( - "`sync_batchnorm` passed in DDPPlugin constructor, but notice that it will be overriden by the trainer setting." + "`sync_batchnorm` passed in DDPPlugin constructor, but notice that it will be overriden by " + "the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() @@ -102,8 +104,7 @@ def num_nodes(self) -> int: @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: - # note that world ranks is related to num_nodes, when resetting these parameters, - # need to reset world ranks + # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks self._num_nodes = num_nodes self.set_world_ranks() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0a9a226ed696e..580db3ee67f9b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -62,12 +62,14 @@ def __init__( super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) if num_nodes is not None: rank_zero_warn( - "`num_nodes` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by the trainer setting." + "`num_nodes` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by " + "the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_warn( - "`sync_batchnorm` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by the trainer setting." + "`sync_batchnorm` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by " + "the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs @@ -86,8 +88,7 @@ def num_nodes(self) -> int: @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: - # note that world ranks is related to num_nodes, when resetting these parameters, - # need to reset world ranks + # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks self._num_nodes = num_nodes self.set_world_ranks() From 63e4a4e1d29bbcc12a285e72bf96bd47d3bec4ad Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 21 Apr 2021 10:46:25 -0700 Subject: [PATCH 42/52] change warn to deprecation --- pytorch_lightning/plugins/training_type/ddp.py | 11 ++++++----- pytorch_lightning/plugins/training_type/ddp_spawn.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7cc42f0d4ac36..ec310cf0565f1 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -33,6 +33,7 @@ _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, + rank_zero_deprecation, rank_zero_warn, ) from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available @@ -73,15 +74,15 @@ def __init__( super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) self.interactive_ddp_procs = [] if num_nodes is not None: - rank_zero_warn( - "`num_nodes` passed in DDPPlugin constructor, but notice that it will be overriden by " - "the trainer setting." + rank_zero_deprecation( + "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Notice that it will be overriden by the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_warn( - "`sync_batchnorm` passed in DDPPlugin constructor, but notice that it will be overriden by " - "the trainer setting." + "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 580db3ee67f9b..611ab1dc9ba22 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -61,15 +61,15 @@ def __init__( ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) if num_nodes is not None: - rank_zero_warn( - "`num_nodes` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by " - "the trainer setting." + rank_zero_deprecation( + "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Notice that it will be overriden by the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_warn( - "`sync_batchnorm` passed in DDPSpawnPlugin constructor, but notice that it will be overriden by " - "the trainer setting." + "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs From 2b8c772be231fce7e6a7031b2611f66d231f2dda Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 21 Apr 2021 13:44:54 -0700 Subject: [PATCH 43/52] fix --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index ec310cf0565f1..3c5b1fed06a26 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -80,7 +80,7 @@ def __init__( ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: - rank_zero_warn( + rank_zero_deprecation( "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " "Notice that it will be overriden by the trainer setting." ) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 611ab1dc9ba22..2244016740f37 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -31,7 +31,13 @@ from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import ( + rank_zero_only, + rank_zero_warn, + rank_zero_deprecation, + ReduceOp, + sync_ddp_if_available, +) from pytorch_lightning.utilities.seed import seed_everything if _TORCH_GREATER_EQUAL_1_8: @@ -67,7 +73,7 @@ def __init__( ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: - rank_zero_warn( + rank_zero_deprecation( "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " "Notice that it will be overriden by the trainer setting." ) @@ -128,6 +134,7 @@ def _is_single_process_single_device(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() From 76016de6c1db8f2aafadd1d1636bb8d0cb7eb72c Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 22 Apr 2021 23:51:18 -0700 Subject: [PATCH 44/52] minor --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5ed15dfbcad53..8d9c5a400401e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -134,7 +134,6 @@ def _is_single_process_single_device(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() From 0996a5d89d99febd93e732369ace22eb9f08b112 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Fri, 23 Apr 2021 00:02:54 -0700 Subject: [PATCH 45/52] remove unnecessary assert --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8d9c5a400401e..d10fdf8ac36c3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -200,7 +200,6 @@ def new_process(self, process_idx, trainer, mp_queue): # move the model to the correct device self.model_to_device() - assert self.sync_batchnorm is not None if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) From 16858be2f71682f154c6c3abbc7bea3d79c41431 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 4 May 2021 03:27:13 -0700 Subject: [PATCH 46/52] comments --- .../plugins/training_type/ddp.py | 4 +- .../plugins/training_type/ddp_spawn.py | 4 +- tests/deprecated_api/test_remove_1-6.py | 52 +++++++++++++++++++ 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 9ea10ba352655..6e969a6f585fd 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -75,13 +75,13 @@ def __init__( self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( - "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " "Notice that it will be overriden by the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_deprecation( - "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " "Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 332ede71d0bd0..6a43914ded002 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -68,13 +68,13 @@ def __init__( super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) if num_nodes is not None: rank_zero_deprecation( - "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " "Notice that it will be overriden by the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_deprecation( - "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.3, and will be removed in v1.5. " + "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " "Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index e69de29bb2d1d..a12c2690a9b21 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.6.0""" +import pytest + + +from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin + + + +def test_v1_6_0_ddp_num_nodes(): + with pytest.deprecated_call( + match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" + ): + ddp_plugin = DDPPlugin( + num_nodes=1, + ) + +def test_v1_6_0_ddp_sync_batchnorm(): + with pytest.deprecated_call( + match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" + ): + ddp_plugin = DDPPlugin( + sync_batchnorm=False, + ) + +def test_v1_6_0_ddp_spawn_num_nodes(): + with pytest.deprecated_call( + match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" + ): + ddp_plugin = DDPSpawnPlugin( + num_nodes=1, + ) + +def test_v1_6_0_ddp_spawn_sync_batchnorm(): + with pytest.deprecated_call( + match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" + ): + ddp_plugin = DDPSpawnPlugin( + sync_batchnorm=False, + ) From 60580be6f51b1ae7d7cdedf91ae0a851520ca6c7 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 4 May 2021 03:31:50 -0700 Subject: [PATCH 47/52] remove extra in change.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cc86105fdf21..0418459765c31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,7 +182,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) -- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) - Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937)) From 20d59a40a2de40cdb21f9de261728b794b700d55 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 4 May 2021 03:32:55 -0700 Subject: [PATCH 48/52] correct in change.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0418459765c31..f9c2b37188395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -181,7 +181,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/)) -- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) +- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) - Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937)) From 0ab7147201f3504f046b6ec7f7bfda5d6f7a4ae0 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 4 May 2021 04:05:57 -0700 Subject: [PATCH 49/52] fix test and flake8 --- .../plugins/training_type/ddp_spawn.py | 2 +- tests/deprecated_api/test_remove_1-6.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 6a43914ded002..5cd6a61050327 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -38,7 +38,7 @@ ReduceOp, sync_ddp_if_available, ) -from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.seed import reset_seed if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index a12c2690a9b21..96d4208c4dd4b 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -13,40 +13,40 @@ # limitations under the License. """Test deprecated functionality which will be removed in v1.6.0""" import pytest - - from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin - def test_v1_6_0_ddp_num_nodes(): with pytest.deprecated_call( match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" ): - ddp_plugin = DDPPlugin( + DDPPlugin( num_nodes=1, ) + def test_v1_6_0_ddp_sync_batchnorm(): with pytest.deprecated_call( match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" ): - ddp_plugin = DDPPlugin( + DDPPlugin( sync_batchnorm=False, ) + def test_v1_6_0_ddp_spawn_num_nodes(): with pytest.deprecated_call( match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" ): - ddp_plugin = DDPSpawnPlugin( + DDPSpawnPlugin( num_nodes=1, ) + def test_v1_6_0_ddp_spawn_sync_batchnorm(): with pytest.deprecated_call( match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" ): - ddp_plugin = DDPSpawnPlugin( + DDPSpawnPlugin( sync_batchnorm=False, ) From 9fdde9414bff76ef4dad7c48ee6931e6aec02b83 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 4 May 2021 22:18:13 +0200 Subject: [PATCH 50/52] pre-commit --- .../plugins/training_type/ddp_spawn.py | 2 +- tests/deprecated_api/test_remove_1-6.py | 33 +++++-------------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5e8e3ae263201..df9f0ee158ba3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -32,9 +32,9 @@ from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import ( + rank_zero_deprecation, rank_zero_only, rank_zero_warn, - rank_zero_deprecation, ReduceOp, sync_ddp_if_available, ) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 96d4208c4dd4b..7d3a766b001b1 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -13,40 +13,25 @@ # limitations under the License. """Test deprecated functionality which will be removed in v1.6.0""" import pytest + from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin def test_v1_6_0_ddp_num_nodes(): - with pytest.deprecated_call( - match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" - ): - DDPPlugin( - num_nodes=1, - ) + with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"): + DDPPlugin(num_nodes=1) def test_v1_6_0_ddp_sync_batchnorm(): - with pytest.deprecated_call( - match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" - ): - DDPPlugin( - sync_batchnorm=False, - ) + with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"): + DDPPlugin(sync_batchnorm=False) def test_v1_6_0_ddp_spawn_num_nodes(): - with pytest.deprecated_call( - match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4" - ): - DDPSpawnPlugin( - num_nodes=1, - ) + with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"): + DDPSpawnPlugin(num_nodes=1) def test_v1_6_0_ddp_spawn_sync_batchnorm(): - with pytest.deprecated_call( - match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4" - ): - DDPSpawnPlugin( - sync_batchnorm=False, - ) + with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"): + DDPSpawnPlugin(sync_batchnorm=False) From 621bfc81b8a218f04522b2dfdaab92fad8208602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 8 May 2021 12:58:40 +0200 Subject: [PATCH 51/52] whitespace standardization --- pytorch_lightning/plugins/training_type/ddp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index b07307c559c46..89e714d57f870 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -75,14 +75,14 @@ def __init__( self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( - "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " - "Notice that it will be overriden by the trainer setting." + "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." + " Notice that it will be overriden by the trainer setting." ) self._num_nodes = num_nodes or 1 if sync_batchnorm is not None: rank_zero_deprecation( - "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. " - "Notice that it will be overriden by the trainer setting." + "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6." + " Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False self.dist = LightningDistributed() From 29f720bdd3995fb0cbd78d5ad5374d7a5333e4ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 May 2021 10:58:59 +0000 Subject: [PATCH 52/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index fe8ebcabf4e2b..6949175d7df14 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -30,7 +30,7 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir): with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"): trainer.has_arg("training_step", "batch") - + def test_v1_6_0_ddp_num_nodes(): with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"): DDPPlugin(num_nodes=1)