@@ -19,10 +19,12 @@
 import torch.nn as nn
 from torch.optim import Adam, Optimizer

+import pytorch_lightning as pl
 from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
+from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset


 def test_lightning_optimizer(tmpdir):
@@ -80,8 +82,8 @@ def configure_optimizers(self):
     assert trainer.optimizers[0].__repr__() == expected


-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended for now.
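A note on the `autospec=True` added to the two `@patch` decorators above: an autospecced patch of an unbound method keeps the real `step` signature, so calls with arguments the real optimizer would reject now fail loudly, and because the replacement behaves like a real function it binds as a method, so the optimizer instance is recorded as the first call argument. A minimal sketch of that behaviour with plain `unittest.mock` and `torch` (nothing Lightning-specific assumed):

    import torch
    from unittest.mock import patch

    # Patch SGD.step with autospec so the mock keeps the real signature.
    with patch("torch.optim.SGD.step", autospec=True) as mock_step:
        optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
        optimizer.step()

        # autospec binds like a real method, so `self` is recorded:
        assert mock_step.call_args[0][0] is optimizer

        # and the real signature is enforced, so a stray keyword such as the
        # `idx=` argument removed further down would now raise a TypeError.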
@@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()

             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)

         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
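Dropping the made-up `idx` argument leaves the two supported call styles, `step()` and `step(closure=...)`, which mirror the contract of `torch.optim.Optimizer.step`: the closure is an optional callable that re-evaluates the model and returns the loss, so optimizers that need several evaluations per step (LBFGS, for example) can invoke it repeatedly. A small plain-PyTorch sketch of that contract, illustrative only and not Lightning's internals:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = torch.randn(8, 4)

    def closure():
        # Re-evaluate the model and return the loss; the optimizer may
        # invoke this several times within a single step().
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        return loss

    optimizer.step(closure)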
@@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool:
     assert len(mock_adam_step.mock_calls) == 8


-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended.
@@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()

             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)

         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -195,9 +197,8 @@ def test_state(tmpdir):
     assert isinstance(lightning_optimizer, Adam)
     assert isinstance(lightning_optimizer, Optimizer)
     lightning_dict = {}
-    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
-                     "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
-                     "_accumulate_grad_batches"]
+    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
+                     "_trainer"]
     for k, v in lightning_optimizer.__dict__.items():
         if k not in special_attrs:
             lightning_dict[k] = v
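The trimmed `special_attrs` list above names the private fields `LightningOptimizer` keeps for itself; every other entry in its `__dict__` is expected to match the wrapped optimizer, which is what `test_state` collects and compares. A hypothetical stand-alone illustration of that kind of check, using an invented `MirroringWrapper` rather than Lightning's actual implementation:

    import torch

    class MirroringWrapper:
        """Invented wrapper that copies the wrapped optimizer's attributes."""

        def __init__(self, optimizer):
            # expose the optimizer's own attributes on the wrapper ...
            self.__dict__.update(optimizer.__dict__)
            # ... plus one private field of our own
            self._optimizer = optimizer

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    wrapper = MirroringWrapper(opt)

    own_fields = ["_optimizer"]
    mirrored = {k: v for k, v in wrapper.__dict__.items() if k not in own_fields}
    assert mirrored == opt.__dict__
    assert wrapper.state == opt.state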
@@ -206,6 +207,55 @@ def test_state(tmpdir):
     assert optimizer.state == lightning_optimizer.state


+def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
+    class OptimizerWrapper(object):
+        def __init__(self, optimizer):
+            self.optim = optimizer
+            self.state_dict = self.optim.state_dict
+            self.load_state_dict = self.optim.load_state_dict
+            self.zero_grad = self.optim.zero_grad
+            self.add_param_group = self.optim.add_param_group
+            self.__setstate__ = self.optim.__setstate__
+            self.__getstate__ = self.optim.__getstate__
+            self.__repr__ = self.optim.__repr__
+
+        @property
+        def __class__(self):
+            return Optimizer
+
+        @property
+        def state(self):
+            return self.optim.state
+
+        @property
+        def param_groups(self):
+            return self.optim.param_groups
+
+        @param_groups.setter
+        def param_groups(self, value):
+            self.optim.param_groups = value
+
+        def step(self):
+            # wrongly defined step. Should contain closure
+            self.optim.step(closure=None)
+
+    class TestLightningOptimizerModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
+            optimizer = OptimizerWrapper(optimizer)
+            return [optimizer]
+
+    model = TestLightningOptimizerModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+
 def test_lightning_optimizer_automatic_optimization(tmpdir):
     """
     Test lightning optimize works with make_optimizer_step in automatic_optimization
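The new `test_lightning_optimizer_with_wrong_optimizer_interface` above relies on a Python quirk: exposing `__class__` as a property is enough to make `isinstance(wrapper, Optimizer)` pass without subclassing, while the deliberately wrong `step()` omits the `closure` parameter (presumably what the new `_support_closure` field tracks). A stand-alone sketch of just the `isinstance` trick, with an invented `FakeOptimizer` class:

    import torch
    from torch.optim import Optimizer

    class FakeOptimizer:
        """Invented class: passes isinstance(obj, Optimizer) without subclassing."""

        def __init__(self, optimizer):
            self.optim = optimizer

        @property
        def __class__(self):
            # isinstance() falls back to obj.__class__ when type(obj) does not
            # match, so reporting Optimizer here is enough to pass the check.
            return Optimizer

    fake = FakeOptimizer(torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.1))
    assert isinstance(fake, Optimizer)
    assert type(fake) is FakeOptimizer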