Skip to content

Commit 36b9ff2

Browse files
ananthsubawaelchlikaushikb11Bordarohitgr7
authored
Deprecate stochastic_weight_avg from the Trainer constructor (#8989)
* Deprecate `stochastic_weight_avg` from the `Trainer` constructor * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Kaushik B <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent a4bc0ac commit 36b9ff2

File tree

4 files changed

+18
-0
lines changed

4 files changed

+18
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
277277
- Deprecate `LightningDistributed` and move the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
278278

279279

280+
- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor in favor of adding the `StochasticWeightAveraging` callback directly to the list of callbacks ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989))
281+
282+
280283
### Removed
281284

282285
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def on_trainer_init(
5252
# init folder paths for checkpoint + weights save callbacks
5353
self.trainer._default_root_dir = default_root_dir or os.getcwd()
5454
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
55+
if stochastic_weight_avg:
56+
rank_zero_deprecation(
57+
"Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7."
58+
" Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`"
59+
" directly to the Trainer's `callbacks` argument instead."
60+
)
5561
self.trainer._stochastic_weight_avg = stochastic_weight_avg
5662

5763
# init callbacks

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,10 @@ def __init__(
375375
stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
376376
<https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_.
377377
378+
.. deprecated:: v1.5
379+
``stochastic_weight_avg`` has been deprecated in v1.5 and will be removed in v1.7.
380+
Please pass :class:`~pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`
381+
directly to the Trainer's ``callbacks`` argument instead.
378382
"""
379383
super().__init__()
380384
Trainer._log_api_event("init")

tests/deprecated_api/test_remove_1-7.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
117117
_ = Trainer(prepare_data_per_node=False)
118118

119119

120+
def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
121+
with pytest.deprecated_call(match=r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"):
122+
_ = Trainer(stochastic_weight_avg=True)
123+
124+
120125
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
121126
class CustomBoringModel(BoringModel):
122127
def on_train_dataloader(self):

0 commit comments

Comments
 (0)