
Commit 7755572

Authored by tchaton, rohitgr7, and Borda
Check if optimizer supports closure (#4981)
* check if optimizer support closure
* cleanup test
* resolve tests
* resolve flake
* update test due to patch limit
* update
* update dep
* Update tests/core/test_lightning_optimizer.py (Co-authored-by: Rohit Gupta <[email protected]>)
* Update tests/core/test_lightning_optimizer.py (Co-authored-by: Rohit Gupta <[email protected]>)
* resolve bug
* update test
* resolve tests
* Update requirements/extra.txt (Co-authored-by: Jirka Borovec <[email protected]>)
* remove bolts dep
* remove bolts
* add missing bolts dep for tests
* remove need for bolts

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 4e6a871 commit 7755572

6 files changed, +81 -22 lines changed

pytorch_lightning/core/optimizer.py

Lines changed: 7 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import types
 from typing import Any, Callable, Optional
 from weakref import proxy
@@ -60,7 +61,7 @@ def __init__(self,
         self._trainer = None
         self._optimizer = optimizer
         self._accumulate_grad_batches = accumulate_grad_batches
-        self._automatic_optimization = None
+        self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
         self._optimizer_idx = None
 
     @property
@@ -73,7 +74,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches):
 
     def _on_trainer_init(self, trainer):
         self._trainer = proxy(trainer)
-        self._automatic_optimization = trainer.train_loop.automatic_optimization
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
@@ -111,7 +111,11 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
 
         else:
             with trainer.profiler.profile(profiler_name):
-                optimizer.step(closure=closure, *args, **kwargs)
+                if self._support_closure:
+                    optimizer.step(closure=closure, *args, **kwargs)
+                else:
+                    closure()
+                    optimizer.step(*args, **kwargs)
 
         accelerator_backend = trainer.accelerator_backend
         if accelerator_backend is not None and accelerator_backend.rpc_enabled:
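
The change above is the heart of the commit: at construction time LightningOptimizer records whether the wrapped optimizer's step() accepts a closure argument, and at step time it either forwards the closure or runs it manually before a plain step(). A minimal standalone sketch of the same pattern, assuming nothing from Lightning itself (run_optimizer_step and the toy closure are illustrative names, not library API):

import inspect

import torch
from torch.optim import SGD


def supports_closure(optimizer) -> bool:
    # Same check as the diff: is there a `closure` parameter on step()?
    return "closure" in inspect.signature(optimizer.step).parameters


def run_optimizer_step(optimizer, closure):
    if supports_closure(optimizer):
        # Built-in optimizers (SGD, Adam, LBFGS, ...) accept a closure.
        optimizer.step(closure=closure)
    else:
        # Fallback for custom optimizers whose step() takes no closure:
        # evaluate the closure (forward + backward) first, then step.
        closure()
        optimizer.step()


model = torch.nn.Linear(4, 1)
optimizer = SGD(model.parameters(), lr=0.1)


def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    return loss


run_optimizer_step(optimizer, closure)

Optimizers such as torch.optim.LBFGS actually require the closure, while a third-party optimizer that omits the parameter now still has its gradients computed before step() is called.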

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool:
 OMEGACONF_AVAILABLE = _module_available("omegaconf")
 HYDRA_AVAILABLE = _module_available("hydra")
 HOROVOD_AVAILABLE = _module_available("horovod.torch")
+BOLTS_AVAILABLE = _module_available("pl_bolts")
 
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
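
BOLTS_AVAILABLE mirrors the existing availability flags so the test suite can skip bolts-dependent tests when pl_bolts is not installed. The real _module_available helper is defined just above this hunk; as a rough sketch of how such a probe can work (an assumption for illustration, not the verbatim Lightning implementation), importlib.util.find_spec checks availability without importing the optional dependency:

import importlib.util


def module_available(module_path: str) -> bool:
    # Probe each package level without importing it, so a missing optional
    # dependency (e.g. pl_bolts or horovod.torch) never raises at import time.
    try:
        parts = module_path.split(".")
        for i in range(1, len(parts) + 1):
            if importlib.util.find_spec(".".join(parts[:i])) is None:
                return False
        return True
    except ModuleNotFoundError:
        return False


BOLTS_AVAILABLE = module_available("pl_bolts")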

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
+https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip

tests/core/test_lightning_module.py

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):
 
     class TestModel(BoringModel):
 
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
         def configure_optimizers(self):
             optimizer = SGD(self.layer.parameters(), lr=0.1)
             optimizer_2 = Adam(self.layer.parameters(), lr=0.1)

tests/core/test_lightning_optimizer.py

Lines changed: 62 additions & 12 deletions
@@ -19,10 +19,12 @@
 import torch.nn as nn
 from torch.optim import Adam, Optimizer
 
+import pytorch_lightning as pl
 from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
+from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset
 
 
 def test_lightning_optimizer(tmpdir):
@@ -80,8 +82,8 @@ def configure_optimizers(self):
     assert trainer.optimizers[0].__repr__() == expected
 
 
-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended for now.
@@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()
 
             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)
 
         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool:
     assert len(mock_adam_step.mock_calls) == 8
 
 
-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended.
@@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()
 
             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)
 
         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -195,9 +197,8 @@ def test_state(tmpdir):
     assert isinstance(lightning_optimizer, Adam)
     assert isinstance(lightning_optimizer, Optimizer)
     lightning_dict = {}
-    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
-                     "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
-                     "_accumulate_grad_batches"]
+    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
+                     "_trainer"]
     for k, v in lightning_optimizer.__dict__.items():
         if k not in special_attrs:
             lightning_dict[k] = v
@@ -206,6 +207,55 @@
     assert optimizer.state == lightning_optimizer.state
 
 
+def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
+    class OptimizerWrapper(object):
+        def __init__(self, optimizer):
+            self.optim = optimizer
+            self.state_dict = self.optim.state_dict
+            self.load_state_dict = self.optim.load_state_dict
+            self.zero_grad = self.optim.zero_grad
+            self.add_param_group = self.optim.add_param_group
+            self.__setstate__ = self.optim.__setstate__
+            self.__getstate__ = self.optim.__getstate__
+            self.__repr__ = self.optim.__repr__
+
+        @property
+        def __class__(self):
+            return Optimizer
+
+        @property
+        def state(self):
+            return self.optim.state
+
+        @property
+        def param_groups(self):
+            return self.optim.param_groups
+
+        @param_groups.setter
+        def param_groups(self, value):
+            self.optim.param_groups = value
+
+        def step(self):
+            # wrongly defined step. Should contain closure
+            self.optim.step(closure=None)
+
+    class TestLightningOptimizerModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
+            optimizer = OptimizerWrapper(optimizer)
+            return [optimizer]
+
+    model = TestLightningOptimizerModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+
 def test_lightning_optimizer_automatic_optimization(tmpdir):
     """
     Test lightning optimize works with make_optimizer_step in automatic_optimization
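
Two details of this test file are worth noting. The new test_lightning_optimizer_with_wrong_optimizer_interface exercises the fallback path: OptimizerWrapper.step() takes no closure, so LightningOptimizer has to run the closure itself before stepping. The switch to @patch(..., autospec=True) matters because _support_closure inspects the signature of optimizer.step: a bare MagicMock exposes only a generic (*args, **kwargs) signature and would hide the closure parameter, whereas an autospecced mock keeps the real SGD.step / Adam.step signature. A small illustration of the difference (the values in the comments are the expected results, not output captured from this commit):

import inspect
from unittest.mock import MagicMock, patch

import torch


# A bare MagicMock has a generic signature, so no `closure` parameter is visible.
plain = MagicMock()
print("closure" in inspect.signature(plain).parameters)         # expected: False

# An autospecced patch keeps SGD.step's real signature (self, closure=None),
# so the `_support_closure` check in LightningOptimizer still finds `closure`.
with patch("torch.optim.SGD.step", autospec=True) as sgd_step:
    print("closure" in inspect.signature(sgd_step).parameters)  # expected: True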

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 5 additions & 6 deletions
@@ -825,7 +825,7 @@ def optimizer_closure():
                 retain_graph = num_backward != backward_idx  # noqa E225
                 self.manual_backward(loss_1, opt, retain_graph=retain_graph)
 
-            opt.step(1, closure=optimizer_closure, something="new")
+            opt.step(closure=optimizer_closure)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
     )
 
     trainer.fit(model)
-    expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
+    expected_calls = [call() for s in range(2)]
     step_mock.assert_has_calls(expected_calls)
 
 
@@ -902,7 +902,7 @@ def dis_closure():
             if batch_idx % 4 == 0 :
                 # Note: Set make_optimizer_step to True or it will use by default
                 # Trainer(accumulate_grad_batches=x)
-                opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
+                opt_dis.step(closure=dis_closure, make_optimizer_step=True)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -933,10 +933,9 @@ def automatic_optimization(self) -> bool:
     )
 
     trainer.fit(model)
-    expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
+    expected_calls = [call(optim='sgd') for s in range(4)]
     mock_sgd_step.assert_has_calls(expected_calls)
-
-    expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
+    expected_calls = [call() for s in range(2)]
     mock_adam_step.assert_has_calls(expected_calls)
 
 