Skip to content

Commit 4610fdd

Browse files
ananthsubawaelchli
andauthored
Mark Trainer.terminate_on_nan protected and deprecate public property (#9849)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent dd6d797 commit 4610fdd

File tree

5 files changed

+33
-10
lines changed

5 files changed

+33
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
284284
- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
285285

286286

287+
- Deprecated `Trainer.terminate_on_nan` public attribute access ([#9849](https://github.com/PyTorchLightning/pytorch-lightning/pull/9849))
288+
289+
287290
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
288291

289292

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def backward_fn(loss: Tensor) -> None:
344344
self._backward(loss, optimizer, opt_idx)
345345

346346
# check if model weights are nan
347-
if self.trainer.terminate_on_nan:
347+
if self.trainer._terminate_on_nan:
348348
detect_nan_parameters(self.trainer.lightning_module)
349349

350350
return backward_fn
@@ -460,7 +460,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
460460

461461
result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
462462

463-
if self.trainer.terminate_on_nan:
463+
if self.trainer._terminate_on_nan:
464464
check_finite_loss(result.closure_loss)
465465

466466
if self.trainer.move_metrics_to_cpu:

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def on_trainer_init(
5252
f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
5353
)
5454

55-
self.trainer.terminate_on_nan = terminate_on_nan
55+
self.trainer._terminate_on_nan = terminate_on_nan
5656
self.trainer.gradient_clip_val = gradient_clip_val
5757
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
5858
self.trainer.track_grad_norm = float(track_grad_norm)

pytorch_lightning/trainer/trainer.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,13 +1999,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
19991999
if self.predicting:
20002000
return self.predict_loop
20012001

2002-
@property
2003-
def train_loop(self) -> FitLoop:
2004-
rank_zero_deprecation(
2005-
"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
2006-
)
2007-
return self.fit_loop
2008-
20092002
@property
20102003
def _ckpt_path(self) -> Optional[str]:
20112004
if self.state.fn == TrainerFn.VALIDATING:
@@ -2055,3 +2048,23 @@ def __getstate__(self):
20552048

20562049
def __setstate__(self, state):
20572050
self.__dict__ = state
2051+
2052+
@property
2053+
def train_loop(self) -> FitLoop:
2054+
rank_zero_deprecation(
2055+
"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
2056+
)
2057+
return self.fit_loop
2058+
2059+
@property
2060+
def terminate_on_nan(self) -> bool:
2061+
rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.")
2062+
return self._terminate_on_nan
2063+
2064+
@terminate_on_nan.setter
2065+
def terminate_on_nan(self, val: bool) -> None:
2066+
rank_zero_deprecation(
2067+
f"Setting `Trainer.terminate_on_nan = {val}` is deprecated in v1.5 and will be removed in 1.7."
2068+
f" Please set `Trainer(detect_anomaly={val})` instead."
2069+
)
2070+
self._terminate_on_nan = val # : 212

tests/deprecated_api/test_remove_1-7.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
131131
assert trainer.terminate_on_nan is terminate_on_nan
132132
assert trainer._detect_anomaly is False
133133

134+
trainer = Trainer()
135+
with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"):
136+
_ = trainer.terminate_on_nan
137+
138+
with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"):
139+
trainer.terminate_on_nan = True
140+
134141

135142
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
136143
class CustomBoringModel(BoringModel):

0 commit comments

Comments
 (0)