From b7162122922f451978ae307bad9480e158bb815a Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Sat, 28 Aug 2021 16:10:36 +0530
Subject: [PATCH 01/20] Add trainer argument for detect_anomaly.

---
 .../connectors/training_trick_connector.py |  2 ++
 pytorch_lightning/trainer/trainer.py       | 10 +++++++++-
 tests/trainer/test_trainer.py              | 15 +++++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 285ed5afbf62b..2cfdabb55a26b 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -29,9 +29,11 @@ def on_trainer_init(
         track_grad_norm: Union[int, float, str],
         accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
+        detect_anomaly: bool,
     ):
         self.trainer.terminate_on_nan = terminate_on_nan
+        self.trainer.detect_anomaly = detect_anomaly
 
         # gradient clipping
         if gradient_clip_algorithm not in list(GradClipAlgorithmType):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 19ccf3935a168..a451cd3796863 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -149,6 +149,7 @@ def __init__(
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
         terminate_on_nan: bool = False,
+        detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
@@ -306,6 +307,8 @@ def __init__(
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                 end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
 
+            detect_anomaly: Enable anomaly detection for the autograd engine.
+
             tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
 
             ipus: How many IPUs to train on.
@@ -439,6 +442,7 @@ def __init__(
             track_grad_norm,
             accumulate_grad_batches,
             terminate_on_nan,
+            detect_anomaly,
         )
         self._setup_on_init(num_sanity_val_steps)
@@ -1121,7 +1125,11 @@ def _run_train(self) -> None:
             self.reset_train_val_dataloaders(model)
 
         self.fit_loop.trainer = self
-        self.fit_loop.run()
+        if self.detect_anomaly:
+            with torch.autograd.detect_anomaly():
+                self.fit_loop.run()
+        else:
+            self.fit_loop.run()
 
     def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         if not self.is_global_zero and self.progress_bar_callback is not None:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0c4833c903a66..dccd5456486e7 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -834,6 +834,21 @@ def on_after_backward(self):
         assert not torch.isfinite(params).all()
 
 
+def test_detect_anomaly_nan(tmpdir):
+    class CurrentModel(BoringModel):
+        def forward(self, x):
+            x /= 0
+            return self.layer(x)
+
+    model = CurrentModel()
+    trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
+    with pytest.warns(
+        UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
+    ):
+        with pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output..*"):
+            trainer.fit(model)
+
+
 def test_trainer_interrupted_flag(tmpdir):
     """Test the flag denoting that a user interrupted training."""

From 60d05356e02cfa5c179dbe316e54c6c0a88ec966 Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Sat, 28 Aug 2021 23:06:02 +0530
Subject: [PATCH 02/20] Deprecate `terminate_on_nan` trainer argument.

---
 pytorch_lightning/trainer/trainer.py |  9 ++++++++-
 tests/trainer/test_trainer.py        | 15 ++++++++++-----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a451cd3796863..14a3931469d7b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -148,7 +148,7 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        terminate_on_nan: bool = False,
+        terminate_on_nan: Optional[bool] = None,
         detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
@@ -306,6 +306,8 @@ def __init__(
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                 end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
+                Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in
+                the v1.7 release. Please use trainer argument `detect_anomaly` instead.
 
             detect_anomaly: Enable anomaly detection for the autograd engine.
 
@@ -435,6 +437,11 @@ def __init__(
             prepare_data_per_node,
         )
 
+        if terminate_on_nan is not None:
+            rank_zero_deprecation(
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+            )
+
         # init training tricks
         self.training_tricks_connector.on_trainer_init(
             gradient_clip_val,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index dccd5456486e7..e2dc3060f12e9 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -803,7 +803,10 @@ def training_step(self, batch, batch_idx):
     model = CurrentModel()
 
     # fit model
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_inf + 1), terminate_on_nan=True)
+    with pytest.deprecated_call(
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+    ):
+        trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_inf + 1), terminate_on_nan=True)
 
     with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
         trainer.fit(model)
@@ -823,7 +826,10 @@ def on_after_backward(self):
             torch.nn.init.constant_(self.layer.bias, math.nan)
 
     model = CurrentModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), terminate_on_nan=True)
+    with pytest.deprecated_call(
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+    ):
+        trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), terminate_on_nan=True)
 
     with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `layer.bias`.*"):
         trainer.fit(model)
@@ -844,9 +850,8 @@ def forward(self, x):
     trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
     with pytest.warns(
         UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
-    ):
-        with pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output..*"):
-            trainer.fit(model)
+    ) and pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output..*"):
+        trainer.fit(model)
 
 
 def test_trainer_interrupted_flag(tmpdir):

From 577c30167b7a4ab1df8d0ca5271ea2d834a27dd7 Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Sun, 29 Aug 2021 08:30:03 +0530
Subject: [PATCH 03/20] Minor Changes for deprecation warnings.

---
 .../connectors/training_trick_connector.py | 16 +++++++++++++---
 pytorch_lightning/trainer/trainer.py       | 11 ++++-------
 tests/deprecated_api/test_remove_1-7.py    |  7 +++++++
 tests/trainer/test_trainer.py              | 10 ++--------
 4 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 2cfdabb55a26b..da9aedf73ab2b 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_deprecation
 
 
 class TrainingTricksConnector:
@@ -28,11 +29,20 @@ def on_trainer_init(
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
         accumulate_grad_batches: Union[int, Dict[int, int]],
-        terminate_on_nan: bool,
+        terminate_on_nan: Optional[bool],
         detect_anomaly: bool,
     ):
-        self.trainer.terminate_on_nan = terminate_on_nan
+        if terminate_on_nan is None:
+            self.trainer.terminate_on_nan = detect_anomaly
+        else:
+            # emit a deprecation warning
+            rank_zero_deprecation(
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release."
+                "Please use trainer argument `detect_anomaly` instead."
+            )
+            self.trainer.terminate_on_nan = terminate_on_nan
+
         self.trainer.detect_anomaly = detect_anomaly
 
         # gradient clipping
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 14a3931469d7b..00f747fdd580a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -306,8 +306,10 @@ def __init__(
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                 end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
-                Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in
-                the v1.7 release. Please use trainer argument `detect_anomaly` instead.
+
+            .. deprecated:: v1.5
+                `Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in
+                the v1.7 release. Please use trainer argument `detect_anomaly` instead.`
 
             detect_anomaly: Enable anomaly detection for the autograd engine.
 
@@ -437,11 +439,6 @@ def __init__(
             prepare_data_per_node,
         )
 
-        if terminate_on_nan is not None:
-            rank_zero_deprecation(
-                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
-            )
-
         # init training tricks
         self.training_tricks_connector.on_trainer_init(
             gradient_clip_val,
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 7581bf2b0c142..55d0dff8eb197 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -87,3 +87,10 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
         match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!"
     ):
         _ = Trainer(prepare_data_per_node=False)
+
+
+def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
+    with pytest.deprecated_call(
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+    ):
+        _ = Trainer(terminate_on_nan=True)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e2dc3060f12e9..b629da98f8549 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -803,10 +803,7 @@ def training_step(self, batch, batch_idx):
     model = CurrentModel()
 
     # fit model
-    with pytest.deprecated_call(
-        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
-    ):
-        trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_inf + 1), terminate_on_nan=True)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_inf + 1), terminate_on_nan=True)
 
     with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
         trainer.fit(model)
@@ -826,10 +823,7 @@ def on_after_backward(self):
             torch.nn.init.constant_(self.layer.bias, math.nan)
 
     model = CurrentModel()
-    with pytest.deprecated_call(
-        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
-    ):
-        trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), terminate_on_nan=True)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), terminate_on_nan=True)
 
     with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `layer.bias`.*"):
         trainer.fit(model)

From 2f3632bfd28cfafeb79f50a1e4243754aefd830e Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Sun, 29 Aug 2021 18:10:24 +0530
Subject: [PATCH 04/20] Fix PEP8 errors

---
 .../trainer/connectors/training_trick_connector.py | 8 ++++----
 tests/deprecated_api/test_remove_1-7.py            | 3 ++-
 tests/trainer/test_trainer.py                      | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index da9aedf73ab2b..f6314d0a6d8ed 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -14,9 +14,8 @@
 from typing import Dict, Optional, Union
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
-from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_deprecation
 
 
 class TrainingTricksConnector:
@@ -38,8 +37,9 @@ def on_trainer_init(
         else:
             # emit a deprecation warning
             rank_zero_deprecation(
-                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release."
-                "Please use trainer argument `detect_anomaly` instead."
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release"
+                " and will be removed in the v1.7 release."
+                " Please use trainer argument `detect_anomaly` instead."
             )
             self.trainer.terminate_on_nan = terminate_on_nan
 
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 55d0dff8eb197..711dac3426474 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -91,6 +91,7 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
     with pytest.deprecated_call(
-        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed"
+        " in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
     ):
         _ = Trainer(terminate_on_nan=True)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index b629da98f8549..437c8705df01e 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -844,7 +844,7 @@ def forward(self, x):
     trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
     with pytest.warns(
         UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
-    ) and pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output..*"):
+    ) and pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output.*"):
         trainer.fit(model)

From 0eb3a283ded611e9e26e7c59bcd604259791dadd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 29 Aug 2021 12:43:25 +0000
Subject: [PATCH 05/20] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/deprecated_api/test_remove_1-7.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index f6b3cc9ca18f3..8b26ec3952290 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -90,6 +90,7 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
     ):
         _ = Trainer(prepare_data_per_node=False)
 
+
 def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
     with pytest.deprecated_call(
         match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed"
@@ -97,6 +98,7 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
     ):
         _ = Trainer(terminate_on_nan=True)
 
+
 def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
     model = BoringModel()
 
@@ -121,4 +123,4 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
 @mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
 def test_v1_7_0_test_tube_logger(_, tmpdir):
     with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
-        _ = TestTubeLogger(tmpdir)
\ No newline at end of file
+        _ = TestTubeLogger(tmpdir)

From b42609582029a13a220b6e7726a62e82975056f0 Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Sun, 29 Aug 2021 19:42:16 +0530
Subject: [PATCH 06/20] Fix error with test_detect_anomaly_nan.

---
 tests/trainer/test_trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 437c8705df01e..f142713bbde1c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -844,8 +844,9 @@ def forward(self, x):
     trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
     with pytest.warns(
         UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
-    ) and pytest.raises(RuntimeError, match=r".*returned nan values in its 0th output.*"):
-        trainer.fit(model)
+    ):
+        with pytest.raises(RuntimeError, match=r"Function 'MseLossBackward' returned nan values in its 0th output."):
+            trainer.fit(model)

From 85bf40ec2634aab2a5206ed833b573f3f70e1c29 Mon Sep 17 00:00:00 2001
From: yopknopixx <30761130+yopknopixx@users.noreply.github.com>
Date: Sun, 29 Aug 2021 21:05:54 +0530
Subject: [PATCH 07/20] Update pytorch_lightning/trainer/trainer.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/trainer/trainer.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 39da02af52737..1af2680ab5a26 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1128,10 +1128,7 @@ def _run_train(self) -> None:
             self.reset_train_val_dataloaders(model)
 
         self.fit_loop.trainer = self
-        if self.detect_anomaly:
-            with torch.autograd.detect_anomaly():
-                self.fit_loop.run()
-        else:
+        with torch.autograd.set_detect_anomaly(self.detect_anomaly):
             self.fit_loop.run()
 
     def _run_evaluate(self) -> _EVALUATE_OUTPUT:

From a62a07a37becf9564692a402ae425d84da3fb8bb Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Thu, 2 Sep 2021 08:26:40 +0530
Subject: [PATCH 08/20] Recommended Changes

---
 CHANGELOG.md                               |  6 ++++
 pytorch_lightning/core/hooks.py            |  2 +-
 .../loops/batch/training_batch_loop.py     |  4 +--
 pytorch_lightning/profiler/pytorch.py      | 28 ++-----------------
 .../connectors/training_trick_connector.py | 10 ++-----
 pytorch_lightning/trainer/trainer.py       |  2 +-
 tests/deprecated_api/test_remove_1-5.py    |  5 ----
 tests/profiler/test_profiler.py            |  6 ----
 tests/trainer/test_trainer.py              |  8 +++---
 9 files changed, 19 insertions(+), 52 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d3dcedbe003a..dba6852dbdcf2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
+
+
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
 
@@ -231,6 +234,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))
 
+- Removed deprecated `profiled_functions` argument from `PyTorchProfiler` ([#9178](https://github.com/PyTorchLightning/pytorch-lightning/pull/9178))
+
 ### Fixed
 
 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
@@ -725,6 +730,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Deprecated
 
+
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
 - Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 220ac589f130c..479a2eae0f8bc 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -20,7 +20,7 @@
 
 from pytorch_lightning.utilities import move_data_to_device
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation
 
 
 class ModelHooks:
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 3f94a0181672e..290e8d978691a 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -277,7 +277,7 @@ def backward_fn(loss: Tensor):
                 self.trainer.dev_debugger.track_train_loss_history(batch_idx, loss)
 
             # check if loss or model weights are nan
-            if self.trainer.terminate_on_nan:
+            if self.trainer.terminate_on_nan or self.trainer.detect_anomaly:
                 check_finite_loss(self.trainer.lightning_module, loss)
 
             return loss
@@ -295,7 +295,7 @@ def _process_closure_result(self, opt_closure_result: Optional[ClosureResult]) -> None:
             return
 
         # check if loss or model weights are nan
-        if self.trainer.terminate_on_nan:
+        if self.trainer.terminate_on_nan or self.trainer.detect_anomaly:
             check_finite_loss(self.trainer.lightning_module, opt_closure_result.loss)
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index 960860953f11b..a196495a2b720 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -24,7 +24,7 @@
 from torch.autograd.profiler import record_function
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
 
@@ -222,7 +222,6 @@ def __init__(
         sort_by_key: Optional[str] = None,
         record_functions: Set[str] = None,
         record_module_names: bool = True,
-        profiled_functions: Optional[List] = None,
         output_filename: Optional[str] = None,
         **profiler_kwargs: Any,
     ) -> None:
@@ -277,14 +276,12 @@ def __init__(
         """
         super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
 
-        record_functions = self.__deprecation_check(profiled_functions, record_functions)
-
         self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False)
         self._emit_nvtx = emit_nvtx
         self._export_to_chrome = export_to_chrome
         self._row_limit = row_limit
         self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
-        self._user_record_functions = record_functions
+        self._user_record_functions = record_functions or set()
         self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
         self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
         self._record_module_names = record_module_names
@@ -331,27 +328,6 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
         with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
         self._profiler_kwargs["with_stack"] = with_stack
 
-    def __deprecation_check(
-        self, profiled_functions: Optional[List[str]], record_functions: Optional[Set[str]]
-    ) -> Set[str]:
-        if record_functions is None:
-            record_functions = set()
-
-        if profiled_functions is not None:
-            rank_zero_deprecation(
-                "`PyTorchProfiler.profiled_functions` has been renamed to"
-                " `record_functions` in v1.3 and will be removed in v1.5"
-            )
-            if not record_functions:
-                record_functions |= set(profiled_functions)
-            else:
-                raise MisconfigurationException(
-                    "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
-                    " Please use only the later."
-                )
-
-        return record_functions
-
     @staticmethod
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index f6314d0a6d8ed..fd8ea4ff7d711 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -28,21 +28,17 @@ def on_trainer_init(
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
         accumulate_grad_batches: Union[int, Dict[int, int]],
-        terminate_on_nan: Optional[bool],
         detect_anomaly: bool,
+        terminate_on_nan: Optional[bool] = None,
     ):
-        if terminate_on_nan is None:
-            self.trainer.terminate_on_nan = detect_anomaly
-        else:
-            # emit a deprecation warning
+        if terminate_on_nan:
             rank_zero_deprecation(
                 "Trainer argument `terminate_on_nan` was deprecated in v1.5 release"
                 " and will be removed in the v1.7 release."
                 " Please use trainer argument `detect_anomaly` instead."
             )
 
-            self.trainer.terminate_on_nan = terminate_on_nan
-
+        self.trainer.terminate_on_nan = terminate_on_nan
         self.trainer.detect_anomaly = detect_anomaly
 
         # gradient clipping
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 1af2680ab5a26..353054e3c999b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -444,8 +444,8 @@ def __init__(
             gradient_clip_algorithm,
             track_grad_norm,
             accumulate_grad_batches,
-            terminate_on_nan,
             detect_anomaly,
+            terminate_on_nan,
         )
         self._setup_on_init(num_sanity_val_steps)
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 40dfc069ac449..e3f4d987ef5e3 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -25,11 +25,6 @@
 from tests.helpers.utils import no_warning_call
 
 
-def test_v1_5_0_legacy_profiler_argument():
-    with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"):
-        PyTorchProfiler(profiled_functions=[])
-
-
 def test_v1_5_0_running_sanity_check():
     trainer = Trainer()
     with pytest.deprecated_call(match="has been renamed to `Trainer.sanity_checking`"):
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 2145ab83e9cdb..327c590ffb01e 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -258,12 +258,6 @@ def test_pytorch_profiler_describe(pytorch_profiler):
     assert len(data) > 0
 
 
-def test_pytorch_profiler_raises(pytorch_profiler):
-    """Ensure errors are raised where expected."""
-    with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"):
-        PyTorchProfiler(profiled_functions=["a"], record_functions=["b"])
-
-
 def test_advanced_profiler_cprofile_deepcopy(tmpdir):
     """Checks for pickle issue reported in #6522"""
     model = BoringModel()
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 30ea5650fd445..5bde751307072 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -842,10 +842,10 @@ def forward(self, x):
 
     model = CurrentModel()
     trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
-    with pytest.warns(
-        UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
-    ):
-        with pytest.raises(RuntimeError, match=r"Function 'MseLossBackward' returned nan values in its 0th output."):
+    with pytest.raises(RuntimeError, match=r"Function 'MseLossBackward' returned nan values in its 0th output."):
+        with pytest.warns(
+            UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
+        ):
             trainer.fit(model)

From ae7daac8b3435ae0b006a62571757e87dce23e79 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 2 Sep 2021 03:07:57 +0000
Subject: [PATCH 09/20] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/trainer/test_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 22338b9c1a048..fcb3f2999a71c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -849,7 +849,6 @@ def forward(self, x):
             trainer.fit(model)
 
 
-
 def test_on_exception_hook(tmpdir):
     """Test the on_exception callback hook and the trainer interrupted flag."""

From 085c0a2521e979b0ccd73d1470dedd17f1cb8cdf Mon Sep 17 00:00:00 2001
From: Yog Dharaskar
Date: Wed, 8 Sep 2021 21:52:13 +0530
Subject: [PATCH 10/20] Recommended changes

---
 pytorch_lightning/loops/utilities.py  | 2 +-
 pytorch_lightning/profiler/pytorch.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index 154680535ef73..e976a29724f1c 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -99,7 +99,7 @@ def _process_training_step_output(
     elif isinstance(training_step_output, torch.Tensor):
         loss = training_step_output
 
-    if trainer.terminate_on_nan:
+    if trainer.terminate_on_nan or trainer.detect_anomaly:
         check_finite_loss(loss)
 
     # the loss shouldn't be moved to cpu.
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index e8e3db232c6fb..6a379a7548f16 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -229,7 +229,6 @@ def __init__(
         sort_by_key: Optional[str] = None,
         record_functions: Set[str] = None,
         record_module_names: bool = True,
-        output_filename: Optional[str] = None,
         **profiler_kwargs: Any,
     ) -> None:
         """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.

From 5ba2c7d0925a969b8b9718fd7370a53723a0a7f8 Mon Sep 17 00:00:00 2001
From: yopknopixx <30761130+yopknopixx@users.noreply.github.com>
Date: Thu, 9 Sep 2021 07:41:03 +0530
Subject: [PATCH 11/20] Update pytorch_lightning/loops/utilities.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/loops/utilities.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index e976a29724f1c..da4bcff575e83 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -99,7 +99,7 @@ def _process_training_step_output(
     elif isinstance(training_step_output, torch.Tensor):
         loss = training_step_output
 
-    if trainer.terminate_on_nan or trainer.detect_anomaly:
+    if trainer.terminate_on_nan and not trainer.detect_anomaly:
         check_finite_loss(loss)
 
     # the loss shouldn't be moved to cpu.
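A note on what the new flag actually turns on: patch 07 above replaced the hand-rolled if/else with a single `torch.autograd.set_detect_anomaly(self.detect_anomaly)` context around `fit_loop.run()`. The wrapped behavior can be reproduced in plain PyTorch with the minimal sketch below (independent of Lightning; the backward-node name in the message, `MseLossBackward` vs. `MseLossBackward0`, depends on the torch version):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0], requires_grad=True)

with torch.autograd.set_detect_anomaly(True):
    bad = x / 0 * 0                         # inf * 0 -> nan, mirroring the `x /= 0` test model
    loss = F.mse_loss(bad, torch.zeros(1))  # nan loss, still attached to the autograd graph
    try:
        loss.backward()                     # anomaly mode checks every backward node's output
    except RuntimeError as err:
        # e.g. "Function 'MseLossBackward0' returned nan values in its 0th output."
        print(err)
```

Alongside the `RuntimeError`, anomaly mode emits a `UserWarning` carrying the traceback of the offending forward call, which is exactly the warning/raise pair `test_detect_anomaly_nan` asserts.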
From 9772e7def12d36624837add8ee2d8e1eb47f45ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 11 Oct 2021 16:17:15 +0200
Subject: [PATCH 12/20] update tests

---
 .../connectors/training_trick_connector.py | 17 +++++++----------
 pytorch_lightning/trainer/trainer.py       |  5 ++---
 tests/deprecated_api/test_remove_1-7.py    |  3 +--
 tests/trainer/test_trainer.py              | 15 ---------------
 4 files changed, 10 insertions(+), 30 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 24c3765513932..a6d635a3d7f55 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Union, Optional
 
 from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -26,18 +26,15 @@ def on_trainer_init(
         gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        terminate_on_nan: bool,
+        terminate_on_nan: Optional[bool],
     ):
-        if not isinstance(terminate_on_nan, bool):
-            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
-
-        if terminate_on_nan:
+        if terminate_on_nan is not None:
             rank_zero_deprecation(
-                "Trainer argument `terminate_on_nan` was deprecated in v1.5 release"
-                " and will be removed in the v1.7 release."
-                " Please use trainer argument `detect_anomaly` instead."
+                "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
+                " Please use `Trainer(detect_anomaly=True)` instead."
             )
-        self.trainer.terminate_on_nan = terminate_on_nan
+            if not isinstance(terminate_on_nan, bool):
+                raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
         if not isinstance(gradient_clip_val, (int, float)):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5540d0d9d1c0b..b4dcab831fd13 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -178,7 +178,6 @@ def __init__(
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
         stochastic_weight_avg: bool = False,
-        detect_anomaly: bool = False,
     ):
         r"""
         Customize every aspect of training via flags.
@@ -353,8 +352,8 @@ def __init__(
             end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
 
             .. deprecated:: v1.5
-                `Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed in
-                the v1.7 release. Please use trainer argument `detect_anomaly` instead.`
+                Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
+                Please use ``detect_anomaly`` instead.
 
             detect_anomaly: Enable anomaly detection for the autograd engine.
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 76c2029d11fc3..9cafc0b7ed813 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -124,8 +124,7 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
 
 def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
     with pytest.deprecated_call(
-        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 release and will be removed"
-        " in the v1.7 release. Please use trainer argument `detect_anomaly` instead."
+        match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
     ):
         _ = Trainer(terminate_on_nan=True)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index dee81ddaf875c..acb0c10df6c63 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -957,21 +957,6 @@ def on_after_backward(self):
         assert not torch.isfinite(params).all()
 
 
-def test_detect_anomaly_nan(tmpdir):
-    class CurrentModel(BoringModel):
-        def forward(self, x):
-            x /= 0
-            return self.layer(x)
-
-    model = CurrentModel()
-    trainer = Trainer(default_root_dir=tmpdir, detect_anomaly=True)
-    with pytest.raises(RuntimeError, match=r"Function 'MseLossBackward' returned nan values in its 0th output."):
-        with pytest.warns(
-            UserWarning, match=r".*Error detected in MseLossBackward. Traceback of forward call that caused the error.*"
-        ):
-            trainer.fit(model)
-
-
 def test_on_exception_hook(tmpdir):
     """Test the on_exception callback hook and the trainer interrupted flag."""

From 132446b32978baf21356f886bcc50b714b2d6563 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 11 Oct 2021 16:18:29 +0200
Subject: [PATCH 13/20] reset

---
 _notebooks                            | 1 -
 pytorch_lightning/profiler/pytorch.py | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)
 delete mode 160000 _notebooks

diff --git a/_notebooks b/_notebooks
deleted file mode 160000
index 4fe3370eac9c4..0000000000000
--- a/_notebooks
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 4fe3370eac9c448eceb36b835ff49ca30de7d404
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index 6a379a7548f16..8bdbadffec15b 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -281,6 +281,7 @@ def __init__(
             If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``.
""" super().__init__(dirpath=dirpath, filename=filename) + self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) self._emit_nvtx = emit_nvtx self._export_to_chrome = export_to_chrome From 5a31539270f235c4f565f79c825e4a3e06563def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 16:18:49 +0200 Subject: [PATCH 14/20] reset _notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..a2fb6468112b7 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From 9fa71cdb85bf4f913e0d27a45a7a5959503ceb79 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Oct 2021 14:20:02 +0000 Subject: [PATCH 15/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/training_trick_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index a6d635a3d7f55..5165056d95391 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Optional +from typing import Optional, Union from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException From 650dc307e3a3968d1f64525747606296c960c2b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 16:22:32 +0200 Subject: [PATCH 16/20] undo empty line --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc24af1fea9ea..6c59417aeb6d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1010,7 +1010,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 ### Deprecated
 
-
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
 - Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))

From 75c76f5bc08df46f2ed0debb8a0b2d862735f900 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 11 Oct 2021 16:24:32 +0200
Subject: [PATCH 17/20] extend test

---
 tests/deprecated_api/test_remove_1-7.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 9cafc0b7ed813..8aaffe0bdf79b 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -122,11 +122,12 @@ def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
         _ = Trainer(stochastic_weight_avg=True)
 
 
-def test_v1_7_0_trainer_terminate_on_nan(tmpdir):
+@pytest.mark.parametrize("terminate_on_nan", [True, False])
+def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
     with pytest.deprecated_call(
         match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
     ):
-        _ = Trainer(terminate_on_nan=True)
+        _ = Trainer(terminate_on_nan=terminate_on_nan)
 
 
 def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):

From 944694cc8996e7a565e4d495da5c559ff3fae466 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 11 Oct 2021 17:38:17 +0200
Subject: [PATCH 18/20] fix merge error

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b4dcab831fd13..3b4f1473ab248 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -168,7 +168,6 @@ def __init__(
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
         terminate_on_nan: Optional[bool] = None,
-        detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
@@ -178,6 +177,7 @@ def __init__(
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
         stochastic_weight_avg: bool = False,
+        detect_anomaly: bool = False,
     ):
         r"""
         Customize every aspect of training via flags.
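For users following along, the end state of the deprecation is a one-line migration (a sketch, assuming a Lightning build with this series applied):

```python
from pytorch_lightning import Trainer

# Deprecated since v1.5, removal planned for v1.7: NaN/inf checks on the loss and
# parameters after each training batch, plus a deprecation warning at construction.
trainer = Trainer(terminate_on_nan=True)

# Replacement: wraps the whole fit loop in torch.autograd.set_detect_anomaly().
trainer = Trainer(detect_anomaly=True)
```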
From 8e8a9e92e789e237c062a6c7ae5da8d44fa05bf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 17:40:12 +0200 Subject: [PATCH 19/20] swap order to support positional args --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3b4f1473ab248..162de67a761dd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -167,7 +167,7 @@ def __init__( reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, replace_sampler_ddp: bool = True, - terminate_on_nan: Optional[bool] = None, + detect_anomaly: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: Optional[bool] = None, plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, @@ -177,7 +177,7 @@ def __init__( move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = "max_size_cycle", stochastic_weight_avg: bool = False, - detect_anomaly: bool = False, + terminate_on_nan: Optional[bool] = None, ): r""" Customize every aspect of training via flags. From 9de66ea396822b11a86ce217f9666a259e2a8bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 18:41:36 +0200 Subject: [PATCH 20/20] add additional asserts for trainer property --- tests/deprecated_api/test_remove_1-7.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 8aaffe0bdf79b..995fa0f2c61b1 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -127,7 +127,9 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan): with pytest.deprecated_call( match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7" ): - _ = Trainer(terminate_on_nan=terminate_on_nan) + trainer = Trainer(terminate_on_nan=terminate_on_nan) + assert trainer.terminate_on_nan is terminate_on_nan + assert trainer._detect_anomaly is False def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
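A closing remark on the test style the series settled on: patch 02 briefly combined the two context managers as `with pytest.warns(...) and pytest.raises(...):`, which patch 06 reverted. That form is a bug, because `and` is a boolean operator that evaluates to its second operand, so only `pytest.raises` is actually entered and the expected-warning check silently never runs. Nesting (or, on Python 3.10+, a parenthesized group of context managers) applies both, as in this self-contained sketch:

```python
import warnings

import pytest


def test_warns_and_raises():
    with pytest.raises(ZeroDivisionError):  # outer context manager
        with pytest.warns(UserWarning):  # inner context manager; both are now active
            warnings.warn("dividing by zero next", UserWarning)
            1 / 0
```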