From 9ed3f7580cb8c0de6e5e494a137b322b2a8e07a1 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 15 Oct 2021 17:13:00 +0200
Subject: [PATCH 1/4] Avoid deprecation warning after #9901

---
 pytorch_lightning/trainer/trainer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index be0a7728edddc..e3385b18ac6ff 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1401,14 +1401,14 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)
 
-        # call the accelerator hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
-            accelerator_hook = getattr(self.accelerator, hook_name)
-            accelerator_output = accelerator_hook(*args, **kwargs)
-            # Rely on the accelerator output if lightningModule hook returns nothing
+        # call the ttp hook
+        if hook_name not in ("setup", "teardown") and hasattr(self.training_type_plugin, hook_name):
+            ttp_hook = getattr(self.training_type_plugin, hook_name)
+            ttp_output = ttp_hook(*args, **kwargs)
+            # Rely on the TTP output if lightningModule hook returns nothing
             # Required for cases such as DataParallel where we reduce the output for the user
             # todo: move this data parallel logic into the data parallel plugin
-            output = accelerator_output if output is None else output
+            output = ttp_output if output is None else output
 
         if pl_module:
             # restore current_fx when nested context

From 472131f20101bd92479b661e01891c6fe11779a8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 15 Oct 2021 20:04:30 +0200
Subject: [PATCH 2/4] Have signatures match

---
 pytorch_lightning/plugins/training_type/ipu.py | 2 +-
 pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +-
 tests/loops/test_training_loop.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index daa704e8a8243..b6728b0551081 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf36a3502702d..9c53069063a52 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
 
diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py
index d491db3bbc91c..ebfe0d4762806 100644
--- a/tests/loops/test_training_loop.py
+++ b/tests/loops/test_training_loop.py
@@ -86,7 +86,7 @@ def run_training(**trainer_kwargs):
 @pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
 def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
     class CurrentModel(BoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             if batch_idx == batch_idx_:
                 return -1
 

From 14e61d8c57b18c5a143d9ec84fc0a526bb9edc73 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 15 Oct 2021 22:14:44 +0200
Subject: [PATCH 3/4] Address review

---
 pytorch_lightning/accelerators/gpu.py | 1 +
 pytorch_lightning/trainer/trainer.py | 22 ++++++++++++++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index b33903c2d60c9..44b29efe6f2bc 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -46,6 +46,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         return super().setup(trainer)
 
     def on_train_start(self) -> None:
+        super().on_train_start()
         # clear cache before training
         torch.cuda.empty_cache()
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e3385b18ac6ff..3ad8fba07eae6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1401,13 +1401,27 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)
 
+        # *Bad code alert*
+        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
+        # The following logic selectively chooses which hooks are called on each object.
+        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
+        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
+        # All of this should be fixed by #8506
+
+        # call the accelerator hook
+        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
+            accelerator_hook = getattr(self.accelerator, hook_name)
+            accelerator_output = accelerator_hook(*args, **kwargs)
+            # Required for cases such as DataParallel where we reduce the output for the user
+            # todo: move this data parallel logic into the data parallel plugin
+            output = accelerator_output if output is None else output
+
         # call the ttp hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.training_type_plugin, hook_name):
+        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+            self.training_type_plugin, hook_name
+        ):
             ttp_hook = getattr(self.training_type_plugin, hook_name)
             ttp_output = ttp_hook(*args, **kwargs)
-            # Rely on the TTP output if lightningModule hook returns nothing
-            # Required for cases such as DataParallel where we reduce the output for the user
-            # todo: move this data parallel logic into the data parallel plugin
             output = ttp_output if output is None else output
 
         if pl_module:

From 1a7773e84c25dca5eb3565326bc9cc3d132a5ab1 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 15 Oct 2021 22:15:33 +0200
Subject: [PATCH 4/4] Add back comment

---
 pytorch_lightning/trainer/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3ad8fba07eae6..e6d8ccde91d71 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1412,6 +1412,7 @@ def call_hook(
         if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
            accelerator_hook = getattr(self.accelerator, hook_name)
             accelerator_output = accelerator_hook(*args, **kwargs)
+            # Rely on the accelerator output if lightningModule hook returns nothing
             # Required for cases such as DataParallel where we reduce the output for the user
             # todo: move this data parallel logic into the data parallel plugin
             output = accelerator_output if output is None else output
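For readers following the hook-routing change above, here is a minimal, self-contained Python sketch of the dispatch order that `Trainer.call_hook` ends up with after patches 3 and 4: the LightningModule hook runs first, only `on_train_start` is additionally forwarded to the `Accelerator`, and every other hook except `setup`/`teardown` is forwarded to the `TrainingTypePlugin`, whose output is used only when the module hook returned nothing. The classes `_Module`, `_Accelerator`, `_TrainingTypePlugin` and the helper `call_hook_sketch` are illustrative stand-ins, not the real Lightning API.

from typing import Any, Optional


class _Module:
    """Stand-in for the LightningModule side of a hook call."""

    def on_train_start(self) -> Optional[Any]:
        print("module: on_train_start")
        return None  # returning nothing lets the accelerator/plugin output win


class _Accelerator:
    """Stand-in accelerator; only `on_train_start` is routed here."""

    def on_train_start(self) -> Optional[Any]:
        print("accelerator: on_train_start (e.g. clear the CUDA cache)")
        return None


class _TrainingTypePlugin:
    """Stand-in TTP; receives every hook except setup/teardown/on_train_start."""

    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Optional[Any]:
        print(f"ttp: on_train_batch_start(batch_idx={batch_idx}, dataloader_idx={dataloader_idx})")
        return None


def call_hook_sketch(module: Any, accelerator: Any, ttp: Any, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
    # 1) the LightningModule hook runs first and its return value takes priority
    output = None
    module_fx = getattr(module, hook_name, None)
    if callable(module_fx):
        output = module_fx(*args, **kwargs)

    # 2) only `on_train_start` is forwarded to the accelerator
    if hook_name in ("on_train_start",) and hasattr(accelerator, hook_name):
        accelerator_output = getattr(accelerator, hook_name)(*args, **kwargs)
        output = accelerator_output if output is None else output

    # 3) everything except setup/teardown/on_train_start goes to the training type plugin;
    #    its output is only used when the module hook returned nothing
    if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(ttp, hook_name):
        ttp_output = getattr(ttp, hook_name)(*args, **kwargs)
        output = ttp_output if output is None else output

    return output


if __name__ == "__main__":
    module, accelerator, ttp = _Module(), _Accelerator(), _TrainingTypePlugin()
    call_hook_sketch(module, accelerator, ttp, "on_train_start")
    call_hook_sketch(module, accelerator, ttp, "on_train_batch_start", batch=[1, 2, 3], batch_idx=0)

This routing likely also motivates the `super().on_train_start()` call added to `GPUAccelerator.on_train_start` in patch 3: once the trainer forwards `on_train_start` only through the accelerator, the accelerator has to delegate to its training type plugin itself for that hook to keep running.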