3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
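For reference, a minimal sketch of how the new flag is meant to be used from user code (`MyLightningModule` is a hypothetical module standing in for any `LightningModule` that defines its own dataloader; it is not part of this PR):

```python
from pytorch_lightning import Trainer
from my_project import MyLightningModule  # hypothetical LightningModule

model = MyLightningModule()

# detect_anomaly=True makes the Trainer run the fit loop inside
# torch.autograd.set_detect_anomaly(True), so a NaN produced during backward
# raises a RuntimeError pointing at the forward op that created it
# (at the cost of noticeably slower training).
trainer = Trainer(max_epochs=1, detect_anomaly=True)
trainer.fit(model)
```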
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -177,6 +177,7 @@ def __init__(
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
stochastic_weight_avg: bool = False,
detect_anomaly: bool = False,
):
r"""
Customize every aspect of training via flags.
@@ -223,6 +224,8 @@ def __init__(
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

detect_anomaly: Enable anomaly detection for the autograd engine.

deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
Default: ``False``.

@@ -488,6 +491,7 @@ def __init__(
track_grad_norm,
terminate_on_nan,
)
self._detect_anomaly: bool = detect_anomaly
self._setup_on_init(num_sanity_val_steps)

# configure tuner
@@ -1184,7 +1188,8 @@ def _run_train(self) -> None:
self.reset_train_val_dataloaders(model)

self.fit_loop.trainer = self
self.fit_loop.run()
with torch.autograd.set_detect_anomaly(self._detect_anomaly):
self.fit_loop.run()

def _run_evaluate(self) -> _EVALUATE_OUTPUT:
if not self.is_global_zero and self.progress_bar_callback is not None:
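The change above only wraps `self.fit_loop.run()` in PyTorch's stock `torch.autograd.set_detect_anomaly` context manager. A standalone sketch of that behaviour, independent of Lightning (the negative input is chosen purely to force a NaN):

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

with torch.autograd.set_detect_anomaly(True):
    # sqrt of a negative value yields NaN in the forward pass
    y = torch.sqrt(x - 10.0).sum()
    try:
        # with anomaly detection on, backward() raises a RuntimeError whose
        # message includes the traceback of the forward call that produced
        # the NaN (e.g. "Function 'SqrtBackward0' returned nan values ...")
        y.backward()
    except RuntimeError as err:
        print("anomaly detected:", err)
```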
16 changes: 16 additions & 0 deletions tests/trainer/test_trainer.py
@@ -2079,3 +2079,19 @@ def test_step(self, *args, **kwargs):
trainer.validate(model)
trainer.test(model)
trainer.predict(model)


def test_detect_anomaly_nan(tmpdir):
class NanModel(BoringModel):
def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
output["loss"] = output["loss"] * torch.tensor(float("nan"))
return output

model = NanModel()
trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."):
with pytest.warns(
UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*"
):
trainer.fit(model)