diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 058f3fc3fb01f..33bcfc645b488 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -334,23 +334,6 @@ def optimizer_step(
         model = model or self.lightning_module
         self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
 
-        if not isinstance(model, pl.LightningModule):
-            # gradient clipping and norm tracking only available with a LightingModule/Trainer
-            return
-
-        trainer = model.trainer
-        assert isinstance(trainer, pl.Trainer)
-        # TODO: this is done for the entire model but should be changed to per-optimizer
-        if opt_idx == 0:
-            self.precision_plugin._track_grad_norm(trainer)
-        self.precision_plugin._clip_gradients(
-            model,
-            optimizer,
-            opt_idx,
-            trainer.gradient_clip_val,
-            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
-        )
-
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
         model_ref = self.lightning_module
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 23a7de220ba04..96c7c333f795f 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -258,10 +258,13 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) ->
 
         The hook is only called if gradients do not need to be accumulated.
         See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
+
         If using native AMP, the loss will be unscaled before calling this hook.
         See these `docs `__
         for more information on the scaling of gradients.
 
+        If clipping gradients, the gradients will not have been clipped yet.
+
         Args:
             optimizer: Current optimizer being used.
             optimizer_idx: Index of the current optimizer being used.
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index cfac84be1367b..246fd89be5a1a 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -576,6 +576,8 @@ def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
         """Override this method to change the default behaviour of ``log_grad_norm``.
 
+        If clipping gradients, the gradients will not have been clipped yet.
+
         Args:
             grad_norm_dict: Dictionary containing current grad norm metrics
 
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 6447c65ecd131..d176a479971df 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -109,8 +109,7 @@ def optimizer_step(
                 f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 01c8661232f93..27ac384d25303 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -63,8 +63,7 @@ def optimizer_step(
                 f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py
index 51e80afed3609..80a7f06ad6688 100644
--- a/pytorch_lightning/plugins/precision/ipu_precision.py
+++ b/pytorch_lightning/plugins/precision/ipu_precision.py
@@ -52,8 +52,7 @@ def optimizer_step(
                 f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index c80e648eff5c4..fe4a840b5337c 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -85,8 +85,7 @@ def optimizer_step(
         closure_result = closure()
         # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
         self.scaler.unscale_(optimizer)
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 8ba8b1a1872ae..f1ebbf58d8326 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -111,6 +111,27 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
         """
         tensor.backward(*args, **kwargs)
 
+    def _after_closure(
+        self, model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
+        """Utility to share some code after the closure has been run."""
+        if not isinstance(model, pl.LightningModule):
+            # none of this applies to Lite
+            return
+        trainer = model.trainer
+        assert trainer is not None
+        trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        # TODO: this is done for the entire model but should be changed to per-optimizer
+        if optimizer_idx == 0:
+            self._track_grad_norm(trainer)
+        self._clip_gradients(
+            model,
+            optimizer,
+            optimizer_idx,
+            trainer.gradient_clip_val,
+            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
+        )
+
     def _wrap_closure(
         self,
         model: "pl.LightningModule",
@@ -125,7 +146,7 @@ def _wrap_closure(
         consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
         """
         closure_result = closure()
-        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         return closure_result
 
     def optimizer_step(
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index a14f3af4d9c17..c482e8a83d7b6 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -71,7 +71,13 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p
     assert isinstance(trainer.precision_plugin, plugin_cls)
 
 
-class GradientUnscaleBoringModel(BoringModel):
+class TestClippingOptimizer(torch.optim.SGD):
+    def step(self, *args, pl_module=None):
+        pl_module.check_grads_clipped()
+        return super().step(*args)
+
+
+class TestPrecisionModel(BoringModel):
     # sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
     def on_after_backward(self) -> None:
         # check grads are scaled
@@ -92,6 +98,12 @@ def check_grads_unscaled(self, optimizer=None):
         for actual, expected in zip(grads, self.original_grads):
             torch.testing.assert_allclose(actual, expected)
 
+    def check_grads_clipped(self):
+        parameters = list(self.parameters())
+        assert len(parameters) == len(self.clipped_parameters)
+        for actual, expected in zip(parameters, self.clipped_parameters):
+            torch.testing.assert_allclose(actual.grad, expected.grad)
+
     def on_before_optimizer_step(self, optimizer, *_):
         self.check_grads_unscaled(optimizer)
         # manually clip
@@ -103,24 +115,28 @@ def on_before_optimizer_step(self, optimizer, *_):
         clip_val = self.trainer.gradient_clip_val
         torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)
 
+    def log_grad_norm(self, grad_norm_dict):
+        self.check_grads_unscaled()
+        assert len(grad_norm_dict)
+
     def configure_gradient_clipping(self, *args, **kwargs):
         # let lightning clip
         super().configure_gradient_clipping(*args, **kwargs)
         # check clipping worked as expected
-        parameters = list(self.parameters())
-        assert len(parameters) == len(self.clipped_parameters)
-        for actual, expected in zip(parameters, self.clipped_parameters):
-            torch.testing.assert_allclose(actual.grad, expected.grad)
+        self.check_grads_clipped()
 
-    def log_grad_norm(self, grad_norm_dict):
-        self.check_grads_unscaled()
-        assert len(grad_norm_dict)
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **_):
+        # pass self as a kwarg
+        optimizer.step(closure, pl_module=self)
+
+    def configure_optimizers(self):
+        return TestClippingOptimizer(self.layer.parameters(), lr=0.1)
 
 
 @RunIf(min_gpus=2)
 @pytest.mark.parametrize("accum", [1, 2])
 def test_amp_gradient_unscale(tmpdir, accum: int):
-    model = GradientUnscaleBoringModel()
+    model = TestPrecisionModel()
 
     trainer = Trainer(
         max_epochs=2,
@@ -137,6 +153,7 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
         gradient_clip_algorithm="value",
         log_every_n_steps=1,
         accumulate_grad_batches=accum,
+        enable_progress_bar=False,
     )
     trainer.fit(model)
 
@@ -200,7 +217,6 @@ def training_step(self, batch, batch_idx):
 @RunIf(min_gpus=2, amp_apex=True)
 @pytest.mark.parametrize("amp_level", ["O2"])
 def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=True,
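
A minimal sketch of the user-facing behaviour this patch documents, assuming an otherwise default ``Trainer`` run (the ``GradNormModel`` below is hypothetical and not part of the patch): gradients observed in ``on_before_optimizer_step`` and ``log_grad_norm`` are already unscaled under native AMP but have not been clipped yet, because clipping now runs inside ``PrecisionPlugin._after_closure`` right after the hook is called.

import torch
from pytorch_lightning import LightningModule, Trainer


class GradNormModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # with ``precision=16`` the scaler would have already unscaled these gradients,
        # but ``gradient_clip_val`` has not been applied to them yet at this point
        norms = [p.grad.norm() for p in self.parameters() if p.grad is not None]
        self.log("unclipped_grad_norm", torch.stack(norms).norm())


if __name__ == "__main__":
    trainer = Trainer(max_steps=4, gradient_clip_val=0.5, enable_progress_bar=False)
    trainer.fit(GradNormModel())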