Skip to content

Commit e966d2e

Browse files
ananthsubtchaton
authored andcommitted
Add automatic optimization property setter to lightning module (#5169)
* add automatic optimization property setter to lightning module * Update test_manual_optimization.py Co-authored-by: chaton <[email protected]> (cherry picked from commit 8748293)
1 parent a8a4954 commit e966d2e

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(self, *args, **kwargs):
102102
self._running_manual_backward = False
103103
self._current_hook_fx_name = None
104104
self._current_dataloader_idx = None
105+
self._automatic_optimization: bool = True
105106

106107
def optimizers(self):
107108
opts = self.trainer.optimizers
@@ -151,7 +152,12 @@ def automatic_optimization(self) -> bool:
151152
"""
152153
If False you are responsible for calling .backward, .step, zero_grad.
153154
"""
154-
return True
155+
return self._automatic_optimization
156+
157+
@automatic_optimization.setter
158+
def automatic_optimization(self, automatic_optimization: bool) -> None:
159+
self._automatic_optimization = automatic_optimization
160+
155161

156162
def print(self, *args, **kwargs) -> None:
157163
r"""

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def test_multiple_optimizers_manual(tmpdir):
3333
Tests that only training_step can be used
3434
"""
3535
class TestModel(BoringModel):
36+
37+
def __init__(self):
38+
super().__init__()
39+
self.automatic_optimization = False
40+
3641
def training_step(self, batch, batch_idx, optimizer_idx):
3742
# manual
3843
(opt_a, opt_b) = self.optimizers()
@@ -69,10 +74,6 @@ def configure_optimizers(self):
6974
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
7075
return optimizer, optimizer_2
7176

72-
@property
73-
def automatic_optimization(self) -> bool:
74-
return False
75-
7677
model = TestModel()
7778
model.val_dataloader = None
7879

0 commit comments

Comments
 (0)