diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst
index 6ca72b8069d6d..06e6e9679d29f 100644
--- a/docs/source/optimizers.rst
+++ b/docs/source/optimizers.rst
@@ -191,37 +191,48 @@ override the :meth:`optimizer_step` function.
 
 For example, here step optimizer A every 2 batches and optimizer B every 4 batches
 
-.. testcode::
+.. note:: When using ``Trainer(enable_pl_optimizer=True)``, there is no need to call ``.zero_grad()``.
 
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        optimizer.step()
+.. testcode::
 
     def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
         optimizer.zero_grad()
 
     # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # update generator opt every 2 steps
         if optimizer_i == 0:
             if batch_nb % 2 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)
 
         # update discriminator opt every 4 steps
         if optimizer_i == 1:
             if batch_nb % 4 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)
+
+.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step``, which can be used as follows.
+
+.. testcode::
+
+    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
+        optimizer.zero_grad()
+
+    # Alternating schedule for optimizer steps (ie: GANs)
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        # update generator opt every 2 steps
+        if optimizer_idx == 0:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
 
-        # ...
-        # add as many optimizers as you want
+        # update discriminator opt every 4 steps
+        if optimizer_idx == 1:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
 
 Here we add a learning-rate warm up
 
 .. testcode::
 
     # learning rate warm-up
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # warm up lr
         if self.trainer.global_step < 500:
             lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
@@ -229,8 +240,20 @@ Here we add a learning-rate warm up
             pg['lr'] = lr_scale * self.hparams.learning_rate
 
         # update params
-        optimizer.step()
-        optimizer.zero_grad()
+        optimizer.step(closure=closure)
+
+The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step.
+
+.. testcode::
+
+    from pytorch_lightning.core.optimizer import LightningOptimizer
+
+    # function hook in LightningModule
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        if not isinstance(optimizer, LightningOptimizer):
+            # wraps into LightningOptimizer only for running the step
+            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
+        optimizer.step(closure=closure)
 
 ----------
 
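A note on the doc changes above: every updated example now forwards the training closure into ``optimizer.step``. The closure follows the standard ``torch.optim`` contract, i.e. a callable that re-evaluates the loss and backpropagates before the parameter update. A minimal plain-PyTorch sketch of that contract (the toy model, data and loss below are illustrative only and not part of this patch)::

    import torch

    # toy model and data, purely for illustration
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch, target = torch.randn(8, 4), torch.randn(8, 1)

    def closure():
        # re-evaluate the loss and backpropagate; optimizers such as LBFGS
        # may call this several times, so it has to be self-contained
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(batch), target)
        loss.backward()
        return loss

    # the built-in torch.optim optimizers accept an optional closure argument
    optimizer.step(closure)
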
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f29e7f75bfbff..ef05ce69c1828 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1170,7 +1170,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
 
     def optimizer_step(
         self,
-        *args,
         epoch: int = None,
         batch_idx: int = None,
         optimizer: Optimizer = None,
@@ -1179,7 +1178,6 @@ def optimizer_step(
         on_tpu: bool = None,
         using_native_amp: bool = None,
         using_lbfgs: bool = None,
-        **kwargs,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1254,7 +1252,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
         if not isinstance(optimizer, LightningOptimizer):
             # wraps into LightingOptimizer only for running step
             optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
-        optimizer.step(closure=optimizer_closure, *args, **kwargs)
+        optimizer.step(closure=optimizer_closure)
 
     def optimizer_zero_grad(
         self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index e6b973b336e43..c8e9ff8b80a2f 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import types
 from typing import Any, Callable, Optional
 from weakref import proxy
@@ -58,12 +57,35 @@ def __init__(self,
         else:
             self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
 
-        self._trainer = None
         self._optimizer = optimizer
+        self._trainer = None
         self._accumulate_grad_batches = accumulate_grad_batches
-        self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
         self._optimizer_idx = None
 
+    @property
+    def defaults(self):
+        return self._optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self._optimizer.defaults = defaults
+
+    @property
+    def state(self):
+        return self._optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self._optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self._optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups):
+        self._optimizer.param_groups = param_groups
+
     @property
     def accumulate_grad_batches(self):
         return self._accumulate_grad_batches
@@ -111,11 +133,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
 
         else:
             with trainer.profiler.profile(profiler_name):
-                if self._support_closure:
-                    optimizer.step(closure=closure, *args, **kwargs)
-                else:
-                    closure()
-                    optimizer.step(*args, **kwargs)
+                optimizer.step(closure=closure, *args, **kwargs)
 
         accelerator_backend = trainer.accelerator_backend
         if accelerator_backend is not None and accelerator_backend.rpc_enabled:
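The ``LightningOptimizer`` change above replaces the ``inspect``-based ``_support_closure`` probe with a direct ``optimizer.step(closure=closure, ...)`` call and adds pass-through properties for ``defaults``, ``state`` and ``param_groups``. As a rough sketch of that delegation pattern in isolation (a generic wrapper with made-up names, not the actual Lightning implementation)::

    import torch

    class OptimizerProxy:
        """Forward attribute reads and writes to a wrapped torch.optim optimizer."""

        def __init__(self, optimizer: torch.optim.Optimizer):
            self._optimizer = optimizer

        @property
        def param_groups(self):
            return self._optimizer.param_groups

        @param_groups.setter
        def param_groups(self, param_groups):
            self._optimizer.param_groups = param_groups

        def step(self, closure=None):
            # the built-in torch.optim optimizers all take an optional closure,
            # so no signature inspection is needed before forwarding it
            return self._optimizer.step(closure=closure)

    # schedulers and user code keep reading/writing param_groups as usual
    proxy = OptimizerProxy(torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1))
    for pg in proxy.param_groups:
        pg["lr"] = 0.01
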
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 20dfb0f4b380f..68a0f4781c9a9 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -477,7 +477,7 @@ def _process_result(self, training_step_output, split_batch):
 
         return training_step_output_for_epoch_end
 
-    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
+    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
         model_ref = self.trainer.get_model()
 
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
@@ -491,16 +491,14 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
 
         # model hook
         model_ref.optimizer_step(
-            epoch=self.trainer.current_epoch,
-            batch_idx=batch_idx,
-            optimizer=optimizer,
-            optimizer_idx=opt_idx,
-            optimizer_closure=train_step_and_backward_closure,
+            self.trainer.current_epoch,
+            batch_idx,
+            optimizer,
+            opt_idx,
+            train_step_and_backward_closure,
             on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
             using_native_amp=using_native_amp,
             using_lbfgs=is_lbfgs,
-            *args,
-            **kwargs,
         )
 
     def on_before_zero_grad(self, optimizer):
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index a7054a3a7ef49..e3a597063d02e 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -103,3 +103,47 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
     assert sgd_zero_grad.call_count == 4
     assert adam_step.call_count == 2
     assert adam_zero_grad.call_count == 2
+
+
+@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
+def test_params_groups_and_state_are_accessible(enable_pl_optimizer, tmpdir):
+
+    with patch("torch.optim.SGD.step") as sgd_step, \
+            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
+            patch("torch.optim.Adam.step") as adam_step, \
+            patch("torch.optim.Adam.zero_grad") as adam_zero_grad:
+
+        class TestModel(BoringModel):
+
+            def training_step(self, batch, batch_idx, optimizer_idx):
+                output = self.layer(batch)
+                loss = self.loss(batch, output)
+                return {"loss": loss}
+
+            def configure_optimizers(self):
+                optimizer = SGD(self.layer.parameters(), lr=0.1)
+                optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
+                return [optimizer, optimizer_2]
+
+            def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
+                               on_tpu=False, using_native_amp=False, using_lbfgs=False):
+                # warm up lr
+                if self.trainer.global_step < 500:
+                    lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
+                    for pg in optimizer.param_groups:
+                        pg['lr'] = lr_scale * 0.01
+
+                optimizer.step(closure=closure)
+
+        model = TestModel()
+        model.training_epoch_end = None
+
+        trainer = Trainer(
+            max_epochs=1,
+            default_root_dir=tmpdir,
+            limit_train_batches=8,
+            accumulate_grad_batches=1,
+            enable_pl_optimizer=enable_pl_optimizer
+        )
+
+        trainer.fit(model)
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index 16963a2af3c0d..a9fcf918cc699 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -193,12 +193,29 @@ def test_state(tmpdir):
     model = torch.nn.Linear(3, 4)
     optimizer = torch.optim.Adam(model.parameters())
     lightning_optimizer = LightningOptimizer(optimizer)
+
+    # test state
+    assert optimizer.state == lightning_optimizer.state
+    lightning_optimizer.state = optimizer.state
+    assert optimizer.state == lightning_optimizer.state
+
+    # test param_groups
+    assert optimizer.param_groups == lightning_optimizer.param_groups
+    lightning_optimizer.param_groups = optimizer.param_groups
+    assert optimizer.param_groups == lightning_optimizer.param_groups
+
+    # test defaults
+    assert optimizer.defaults == lightning_optimizer.defaults
+    lightning_optimizer.defaults = optimizer.defaults
+    assert optimizer.defaults == lightning_optimizer.defaults
+
     assert isinstance(lightning_optimizer, LightningOptimizer)
     assert isinstance(lightning_optimizer, Adam)
     assert isinstance(lightning_optimizer, Optimizer)
     lightning_dict = {}
     special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
-                     "_trainer"]
+                     "_trainer", "__getstate__", "__setstate__", "state_dict", "load_state_dict",
+                     "zero_grad", "__setstate__", "add_param_group"]
     for k, v in lightning_optimizer.__dict__.items():
         if k not in special_attrs:
             lightning_dict[k] = v
@@ -207,55 +224,6 @@
     assert optimizer.state == lightning_optimizer.state
 
 
-def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
-    class OptimizerWrapper(object):
-        def __init__(self, optimizer):
-            self.optim = optimizer
-            self.state_dict = self.optim.state_dict
-            self.load_state_dict = self.optim.load_state_dict
-            self.zero_grad = self.optim.zero_grad
-            self.add_param_group = self.optim.add_param_group
-            self.__setstate__ = self.optim.__setstate__
-            self.__getstate__ = self.optim.__getstate__
-            self.__repr__ = self.optim.__repr__
-
-        @property
-        def __class__(self):
-            return Optimizer
-
-        @property
-        def state(self):
-            return self.optim.state
-
-        @property
-        def param_groups(self):
-            return self.optim.param_groups
-
-        @param_groups.setter
-        def param_groups(self, value):
-            self.optim.param_groups = value
-
-        def step(self):
-            # wrongly defined step. Should contain closure
-            self.optim.step(closure=None)
-
-    class TestLightningOptimizerModel(BoringModel):
-
-        def configure_optimizers(self):
-            optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
-            optimizer = OptimizerWrapper(optimizer)
-            return [optimizer]
-
-    model = TestLightningOptimizerModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        weights_summary=None,
-        log_every_n_steps=1,
-    )
-    trainer.fit(model)
-
-
 def test_lightning_optimizer_automatic_optimization(tmpdir):
     """
     Test lightning optimize works with make_optimizer_step in automatic_optimization
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 9e369a874acd0..33d14e852b285 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
     )
     trainer.fit(model)
 
-    expected_calls = [call() for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     step_mock.assert_has_calls(expected_calls)
 
 
@@ -933,9 +933,9 @@ def automatic_optimization(self) -> bool:
     )
     trainer.fit(model)
 
-    expected_calls = [call(optim='sgd') for s in range(4)]
+    expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
     mock_sgd_step.assert_has_calls(expected_calls)
 
-    expected_calls = [call() for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     mock_adam_step.assert_has_calls(expected_calls)
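The updated expectations above assert that every mocked ``step`` call now receives a ``closure`` keyword, using ``unittest.mock.ANY`` as a wildcard for the closure object. A tiny standalone illustration of that assertion style (the mock below is illustrative, not taken from the test suite)::

    from unittest.mock import ANY, MagicMock, call

    step_mock = MagicMock()

    # simulate two optimizer steps, each invoked with some closure keyword
    step_mock(closure=lambda: None)
    step_mock(closure=lambda: None)

    # ANY matches any closure object, mirroring the expectations above
    step_mock.assert_has_calls([call(closure=ANY) for _ in range(2)])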