diff --git a/CHANGELOG.md b/CHANGELOG.md index 01076b46ec073..00bd7c84b0fb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525)) +- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fa018a345262b..af225e708e343 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -177,6 +177,7 @@ def __init__( move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = "max_size_cycle", stochastic_weight_avg: bool = False, + detect_anomaly: bool = False, ): r""" Customize every aspect of training via flags. @@ -223,6 +224,8 @@ def __init__( Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + detect_anomaly: Enable anomaly detection for the autograd engine. + deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms. Default: ``False``. @@ -488,6 +491,7 @@ def __init__( track_grad_norm, terminate_on_nan, ) + self._detect_anomaly: bool = detect_anomaly self._setup_on_init(num_sanity_val_steps) # configure tuner @@ -1184,7 +1188,8 @@ def _run_train(self) -> None: self.reset_train_val_dataloaders(model) self.fit_loop.trainer = self - self.fit_loop.run() + with torch.autograd.set_detect_anomaly(self._detect_anomaly): + self.fit_loop.run() def _run_evaluate(self) -> _EVALUATE_OUTPUT: if not self.is_global_zero and self.progress_bar_callback is not None: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 021812464e972..acb0c10df6c63 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2079,3 +2079,19 @@ def test_step(self, *args, **kwargs): trainer.validate(model) trainer.test(model) trainer.predict(model) + + +def test_detect_anomaly_nan(tmpdir): + class NanModel(BoringModel): + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + output["loss"] = output["loss"] * torch.tensor(float("nan")) + return output + + model = NanModel() + trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True) + with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."): + with pytest.warns( + UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" + ): + trainer.fit(model)