diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc28cb080eb23..f9a521eed5815 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -113,6 +113,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
 
 
+- The trainer no longer tries to save a checkpoint on exception or run callback's `on_train_end` functions ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864))
+
+
 - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
 
 
@@ -258,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
 
 
+- Fixed bug for trainer error handling which would cause hang for distributed training ([#6864](https://github.com/PyTorchLightning/pytorch-lightning/pull/6864))
+
+
 - Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
 
 
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 96e19a7be4694..0a93baeee98b5 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -144,8 +144,8 @@ So you can run it like so:
 .. note::
     If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
     The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown, including
-    running callbacks such as ``on_train_end``. The trainer object will also set an attribute
-    ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
+    running accelerator callback ``on_train_end`` to clean up memory. The trainer object will also set
+    an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
     resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs.
 
 ------------
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 88a313faee395..6e4948fd6d019 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -16,7 +16,6 @@
 import warnings
 from itertools import count
 from pathlib import Path
-from traceback import print_exc
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
@@ -420,7 +419,7 @@ def fit(
         # we reuse fit for other functions. When already set, it shouldn't be modified.
         if not self.state.running:
             self.state = TrainerState.FITTING
-        if self._running_stage is None:
+        if self._running_stage is None or self.tuning:
             self.training = True
 
         # set local properties on the model
@@ -607,6 +606,7 @@ def run_train(self) -> None:
                 self.train_loop.run_training_epoch()
 
                 if self.max_steps and self.max_steps <= self.global_step:
+                    self.train_loop.on_train_end()
                     return
 
                 # early stopping
@@ -615,6 +615,7 @@ def run_train(self) -> None:
 
                 if self.should_stop:
                     if met_min_epochs and met_min_steps:
+                        self.train_loop.on_train_end()
                         return
                     else:
                         log.info(
@@ -633,14 +634,15 @@ def run_train(self) -> None:
             if not self.interrupted:
                 self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
-        except (RuntimeError, AssertionError):
-            # if an exception is raised, the finally block is executed and can hide the actual exception
-            # that was initially raised if `on_train_end` also raises an exception. we want to avoid that
-            # for assertions and other runtime errors so we aren't misled while debugging
-            print_exc()
-        finally:
-            # hook
-            self.train_loop.on_train_end()
+                # same treatment as below
+                self.accelerator.on_train_end()
+                self._running_stage = None
+        except BaseException:
+            # give accelerators a chance to finish
+            self.accelerator.on_train_end()
+            # reset bookkeeping
+            self._running_stage = None
+            raise
 
     def run_evaluation(self, on_epoch=False):
         if not (self.evaluating or self.sanity_checking):
diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py
index d1bcee43b1f02..7c5a6c03766dc 100644
--- a/tests/callbacks/test_callback_hook_outputs.py
+++ b/tests/callbacks/test_callback_hook_outputs.py
@@ -35,8 +35,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
             assert 'x' in outputs
 
         def on_train_epoch_end(self, trainer, pl_module, outputs):
-            d = outputs[0]
-            assert len(d) == trainer.num_training_batches
+            assert len(outputs) == trainer.num_training_batches
 
     class TestModel(BoringModel):
 
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 923821a5e50e4..42105a69596bd 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -171,8 +171,10 @@ def train_dataloader(self):
                 sampler=None,
             )
 
-        def training_step_end(self, *_):
+        def training_step_end(self, training_step_output):
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
+            # must return
+            return training_step_output
 
     model = TestModel()
     model.training_epoch_end = None
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 70db6208164aa..4c8cf99a275f0 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -29,7 +29,7 @@
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_multiple_optimizers_manual(tmpdir):
+def test_multiple_optimizers_manual_no_return(tmpdir):
     """
     Tests that only training_step can be used
     """
@@ -68,8 +68,9 @@ def training_step(self, batch, batch_idx):
             assert torch.all(self.layer.weight.grad == 0)
 
         def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
+            # outputs is empty as training_step does not return
+            # and it is not automatic optimization
+            assert len(outputs) == 0
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -279,8 +280,9 @@ def training_step(self, batch, batch_idx):
             assert torch.all(self.layer.weight.grad == 0)
 
         def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
+            # outputs is empty as training_step does not return
+            # and it is not automatic optimization
+            assert len(outputs) == 0
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -310,7 +312,7 @@ def configure_optimizers(self):
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @RunIf(min_gpus=1, amp_apex=True)
-def test_multiple_optimizers_manual_apex(tmpdir):
+def test_multiple_optimizers_manual_apex_no_return(tmpdir):
     """
     Tests that only training_step can be used
     """
@@ -353,8 +355,9 @@ def training_step(self, batch, batch_idx):
             assert torch.all(self.layer.weight.grad == 0)
 
         def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
+            # outputs is empty as training_step does not return
+            # and it is not automatic optimization
+            assert len(outputs) == 0
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -638,6 +641,8 @@ def training_step(self, batch, batch_idx):
             opt_b.step()
             opt_b.zero_grad()
 
+            return {'loss1': loss_1, 'loss2': loss_2}
+
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2
@@ -724,10 +729,6 @@ def optimizer_closure():
             weight_after = self.layer.weight.clone()
             assert not torch.equal(weight_before, weight_after)
 
-        def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
-
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             return optimizer
@@ -788,10 +789,6 @@ def optimizer_closure():
             else:
                 assert self.layer.weight.grad is not None
 
-        def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
-
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             return optimizer
@@ -845,10 +842,6 @@ def optimizer_closure():
             opt.step(closure=optimizer_closure)
             opt.zero_grad()
 
-        def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
-
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             return optimizer
@@ -923,10 +916,6 @@ def dis_closure():
             opt_dis.step(closure=dis_closure)
             opt_dis.zero_grad()
 
-        def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
-
         def configure_optimizers(self):
             optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
@@ -1031,10 +1020,6 @@ def dis_closure():
             if make_dis_optimizer_step:
                 opt_dis.step(closure=dis_closure)
 
-        def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
-
         def configure_optimizers(self):
             optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index 5f0ca34015df0..24b32c8725963 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -134,8 +134,9 @@ def training_step(self, batch, batch_idx):
             opt_b.zero_grad()
 
         def training_epoch_end(self, outputs) -> None:
-            # outputs should be an array with an entry per optimizer
-            assert len(outputs) == 2
+            # outputs is empty as training_step does not return
+            # and it is not automatic optimization
+            assert len(outputs) == 0
 
     model = TestModel()
     model.val_dataloader = None
diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py
index e8d5fcd4c3b95..1349659cc4595 100644
--- a/tests/trainer/test_training_loop.py
+++ b/tests/trainer/test_training_loop.py
@@ -58,7 +58,17 @@ def on_after_backward(self):
             super().on_after_backward()
 
         def optimizer_step(
-                self,
+            self,
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_idx,
+            optimizer_closure,
+            on_tpu,
+            using_native_amp,
+            using_lbfgs,
+        ):
+            super().optimizer_step(
                 epoch,
                 batch_idx,
                 optimizer,
@@ -67,11 +77,10 @@ def optimizer_step(
                 on_tpu,
                 using_native_amp,
                 using_lbfgs,
-        ):
-            super().optimizer_step(
-                epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
             )
-            self.called.append("optimizer_step")  # append after as closure calls other methods
+            self.called.append(
+                "optimizer_step"
+            )  # append after as closure calls other methods
 
         def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
             self.called.append("on_train_batch_end")
@@ -106,23 +115,23 @@ def on_epoch_end(self):
     trainer.fit(model)
 
     expected = [
-        'on_epoch_start',  # validation
-        'on_epoch_end',
-        'on_epoch_start',  # training
-        'on_train_epoch_start',
-        'on_train_batch_start',
-        'training_step',
-        'on_before_zero_grad',
-        'optimizer_zero_grad',
-        'backward',
-        'on_after_backward',
-        'optimizer_step',
-        'on_train_batch_end',
-        'training_epoch_end',
-        'on_train_epoch_end',
-        'on_epoch_end',
-        'on_epoch_start',  # validation
-        'on_epoch_end'
+        "on_epoch_start",  # validation
+        "on_epoch_end",
+        "on_epoch_start",  # training
+        "on_train_epoch_start",
+        "on_train_batch_start",
+        "training_step",
+        "on_before_zero_grad",
+        "optimizer_zero_grad",
+        "backward",
+        "on_after_backward",
+        "optimizer_step",
+        "on_train_batch_end",
+        "training_epoch_end",
+        "on_train_epoch_end",
+        "on_epoch_end",
+        "on_epoch_start",  # validation
+        "on_epoch_end",
     ]
     assert model.called == expected
 
@@ -132,15 +141,16 @@ def test_outputs_format(tmpdir):
     class HookedModel(BoringModel):
 
         def training_step(self, batch, batch_idx):
-            self.log("foo", "bar")
-            return super().training_step(batch, batch_idx)
+            output = super().training_step(batch, batch_idx)
+            self.log("foo", 123)
+            output["foo"] = 123
+            return output
 
         @staticmethod
         def _check_output(output):
             assert "loss" in output
-            assert "foo" in output
-            assert output["foo"] == "bar"
+            assert output["foo"] == 123
 
         def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
             HookedModel._check_output(outputs)