
Commit 173f4c8

yopknopixx, carmocca, and awaelchli authored
Deprecate terminate_on_nan Trainer argument in favor of detect_anomaly (#9175)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6a0c47a commit 173f4c8
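
For anyone migrating, the change amounts to swapping one Trainer flag for another. A minimal sketch of the migration (illustrative only, not part of this commit's diff):

```python
from pytorch_lightning import Trainer

# Before: deprecated in v1.5 and slated for removal in v1.7; now emits a
# DeprecationWarning, and only checks loss/parameters at the end of each batch.
# trainer = Trainer(terminate_on_nan=True)

# After: delegate to torch.autograd's anomaly detection, which flags the
# specific backward op that produced NaN values.
trainer = Trainer(detect_anomaly=True)
```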

File tree

4 files changed: +31 -7 lines changed

- CHANGELOG.md
- pytorch_lightning/trainer/connectors/training_trick_connector.py
- pytorch_lightning/trainer/trainer.py
- tests/deprecated_api/test_remove_1-7.py

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

```diff
@@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
+
+
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
 
 
```
pytorch_lightning/trainer/connectors/training_trick_connector.py
Lines changed: 10 additions & 5 deletions

```diff
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union
 
-from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -26,10 +26,15 @@ def on_trainer_init(
         gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        terminate_on_nan: bool,
+        terminate_on_nan: Optional[bool],
     ):
-        if not isinstance(terminate_on_nan, bool):
-            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
+        if terminate_on_nan is not None:
+            rank_zero_deprecation(
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
+                " Please use `Trainer(detect_anomaly=True)` instead."
+            )
+            if not isinstance(terminate_on_nan, bool):
+                raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
         if not isinstance(gradient_clip_val, (int, float)):
```
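With this change the connector warns instead of silently accepting the old flag; `rank_zero_deprecation` emits the warning only from the rank-zero process in distributed runs. A hedged sketch of what a caller would observe, assuming a single-process run and warning filters relaxed:

```python
import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is hidden by default
    Trainer(terminate_on_nan=True)   # still works, but is deprecated

assert any("terminate_on_nan" in str(w.message) for w in caught)
```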

pytorch_lightning/trainer/trainer.py
Lines changed: 8 additions & 2 deletions

```diff
@@ -167,7 +167,7 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        terminate_on_nan: bool = False,
+        detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
@@ -177,7 +177,7 @@ def __init__(
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
         stochastic_weight_avg: bool = False,
-        detect_anomaly: bool = False,
+        terminate_on_nan: Optional[bool] = None,
     ):
         r"""
         Customize every aspect of training via flags.
@@ -351,6 +351,12 @@ def __init__(
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                 end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
 
+                .. deprecated:: v1.5
+                    Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
+                    Please use ``detect_anomaly`` instead.
+
+            detect_anomaly: Enable anomaly detection for the autograd engine.
+
             tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
 
             ipus: How many IPUs to train on.
```
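The one-line docstring for `detect_anomaly` hides a useful behavioral difference: `terminate_on_nan` checked the loss and parameters after each batch, while anomaly detection makes the autograd engine raise at the exact backward op that produced NaN. A short sketch of the underlying PyTorch mechanism (plain `torch`, not the Trainer internals):

```python
import torch

x = torch.tensor([0.0], requires_grad=True)
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x) * 0  # forward is finite, but sqrt's backward hits 0/0 = nan
    y.backward()           # RuntimeError names SqrtBackward as the culprit
```

Anomaly mode records a forward stack trace for every op, so it slows training noticeably; it is a debugging aid rather than a production setting.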

tests/deprecated_api/test_remove_1-7.py
Lines changed: 10 additions & 0 deletions

```diff
@@ -122,6 +122,16 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
         _ = Trainer(stochastic_weight_avg=True)
 
 
+@pytest.mark.parametrize("terminate_on_nan", [True, False])
+def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
+    with pytest.deprecated_call(
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
+    ):
+        trainer = Trainer(terminate_on_nan=terminate_on_nan)
+    assert trainer.terminate_on_nan is terminate_on_nan
+    assert trainer._detect_anomaly is False
+
+
 def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
     class CustomBoringModel(BoringModel):
         def on_train_dataloader(self):
```
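The test leans on `pytest.deprecated_call`, which fails unless the wrapped code emits a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches the given regex. A self-contained illustration of the same pattern, using a hypothetical `old_api` helper:

```python
import warnings

import pytest


def old_api():
    # stand-in for a deprecated function that warns on use
    warnings.warn("old_api was deprecated in v1.5", DeprecationWarning)


def test_old_api_warns():
    with pytest.deprecated_call(match="deprecated in v1.5"):
        old_api()
```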
