diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e97450fdbd885..bdf415bd78a35 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -322,7 +322,9 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)
+        self.precision_plugin.clip_gradients(
+            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
 
     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 087f6df7a1c6a..e51393a53122f 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _XLA_AVAILABLE:
@@ -56,7 +56,12 @@ def run_optimizer_step(
     ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+    def clip_gradients(
+        self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0,
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
+    ) -> None:
+        assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
+            "Only NORM gradient clipping is supported on TPU for now"
 
         model = self.lightning_module
         parameters = model.parameters()
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 899ffbf56e8fd..69efe1cc83c3f 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -39,7 +39,7 @@ def on_trainer_init(
                 f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
             )
         self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm
+        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
 
         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d2749733812b3..0d2c4dca75fe5 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -381,7 +381,10 @@ def track_and_norm_grad(self, optimizer):
         grad_norm_dic = self._track_gradient_norm()
 
         # clip gradients
-        self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
+        self.trainer.accelerator.clip_gradients(
+            optimizer, self.trainer.gradient_clip_val,
+            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
+        )
         self._cur_grad_norm_dict = grad_norm_dic
 
     def _track_gradient_norm(self):
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 6409f2ef4bcbf..b9c0111d313a5 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -222,7 +222,7 @@ def test_tpu_grad_norm(tmpdir):
 @RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_clip_grad_by_value(tmpdir):
-    """Test if clip_gradients by value works on TPU."""
+    """Test if clip_gradients by value works on TPU. (It should not.)"""
     tutils.reset_seed()
     trainer_options = dict(
         default_root_dir=tmpdir,
@@ -236,7 +236,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
     )
 
     model = BoringModel()
-    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+    with pytest.raises(AssertionError):
+        tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
 @RunIf(tpu=True)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 447ed2b41b8d6..71f4b9371335c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir):
 
     model = BoringModel()
 
-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-10
     trainer = Trainer(
-        max_steps=10,
+        max_steps=1,
         max_epochs=1,
         gradient_clip_val=grad_clip_val,
         gradient_clip_algorithm='value',
-        default_root_dir=tmpdir,
+        default_root_dir=tmpdir
     )
 
     trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
@@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
         parameters = model.parameters()
         grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
         grad_max = torch.max(torch.stack(grad_max_list))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
 
         return ret_val
@@ -996,9 +996,9 @@ def test_gradient_clipping_by_value_fp16(tmpdir):
     tutils.reset_seed()
 
     model = BoringModel()
-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-10
     trainer = Trainer(
-        max_steps=10,
+        max_steps=1,
         max_epochs=1,
         precision=16,
         gpus=1,
@@ -1016,9 +1016,10 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
         # test that gradient is clipped correctly
         ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
         parameters = model.parameters()
-        grad_max = torch.max(torch.stack([p.grad.detach() for p in parameters]))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
+        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
+        grad_max = torch.max(torch.stack(grad_max_list))
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
 
         return ret_val
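
Note: the diff above threads the `gradient_clip_algorithm` flag from the `Trainer` through the training loop and accelerator into the precision plugin. As a rough usage sketch (not part of the PR; the `TinyModel` module and random dataset below are placeholders invented for illustration), the option is selected on the `Trainer` exactly as in the tests above:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    """Placeholder module, only here to exercise the gradient clipping flags."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        gradient_clip_val=0.5,            # clipping threshold
        gradient_clip_algorithm='value',  # clip each gradient element to [-0.5, 0.5]; the default is 'norm'
    )
    trainer.fit(TinyModel(), train_data)

Per the tpu.py change above, requesting 'value' clipping with a TPU accelerator is expected to trip the new assertion, since only norm-based clipping is supported on TPU for now.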