From 53b149b2cbb3e19c181c8101e4fe1f167e48ef63 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 18 Aug 2021 21:18:15 -0700 Subject: [PATCH 1/3] Deprecate `stochastic_weight_avg` from the `Trainer` constructor --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/connectors/callback_connector.py | 7 +++++++ pytorch_lightning/trainer/trainer.py | 4 ++++ tests/deprecated_api/test_remove_1-7.py | 7 ++++++- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99996c6281938..eb496c4ffe7fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,7 +106,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) -- +- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor - diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index cd8183b68ec82..c4ba348310b00 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -19,6 +19,7 @@ from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import rank_zero_deprecation class CallbackConnector: @@ -39,6 +40,12 @@ def on_trainer_init( # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir + if stochastic_weight_avg: + rank_zero_deprecation( + "Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7." + " Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`" + " directly to the Trainer's ``callbacks`` argument instead." + ) self.trainer._stochastic_weight_avg = stochastic_weight_avg # init callbacks diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 887cdd46a9db2..d1a4cde415bdc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -331,6 +331,10 @@ def __init__( stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA) _` + .. deprecated:: v1.5 + ``stochastic_weight_avg`` has been deprecated in v1.5 and will be removed in v1.7. + Please pass :class:`~pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging` + directly to the Trainer's ``callbacks`` argument instead. """ super().__init__() Trainer._log_api_event("init") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index d836f1427a110..76eec4e9f9ea4 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -15,7 +15,7 @@ import pytest -from pytorch_lightning import LightningDataModule +from pytorch_lightning import LightningDataModule, Trainer from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule @@ -80,3 +80,8 @@ def test_v1_7_0_datamodule_dims_property(tmpdir): _ = dm.dims with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"): _ = LightningDataModule(dims=(1, 1, 1)) + + +def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir): + with pytest.deprecated_call(match=r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"): + _ = Trainer(stochastic_weight_avg=True) From 6a58859e466509e1245d0ea803d4b5f208f8486a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 18 Aug 2021 21:20:59 -0700 Subject: [PATCH 2/3] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb496c4ffe7fc..37f5940fcad4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,7 +106,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) -- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor +- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989)) - From 9c49409e3f150187a40622445ba16a88d8652b17 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 19 Aug 2021 07:47:18 -0700 Subject: [PATCH 3/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37f5940fcad4b..d1622eb6c5506 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,7 +106,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) -- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989)) +- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor in favor of adding the `StochasticWeightAveraging` callback directly to the list of callbacks ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989)) - diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c4ba348310b00..bf6b8a1c16d62 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -44,7 +44,7 @@ def on_trainer_init( rank_zero_deprecation( "Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7." " Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`" - " directly to the Trainer's ``callbacks`` argument instead." + " directly to the Trainer's `callbacks` argument instead." ) self.trainer._stochastic_weight_avg = stochastic_weight_avg