From 858507132f6543a90907d9f0a52295ec45589bdf Mon Sep 17 00:00:00 2001
From: Ceshine Lee
Date: Sat, 10 Apr 2021 00:09:38 +0800
Subject: [PATCH 1/8] An attempt to fix `gradient_clip_algorithm` problem (#6920)

Also add a temporary workaround to #6807
---
 pytorch_lightning/accelerators/accelerator.py | 4 +++-
 pytorch_lightning/accelerators/tpu.py | 23 +++++++++++++++++--
 .../connectors/training_trick_connector.py | 2 +-
 pytorch_lightning/trainer/trainer.py | 3 ++-
 pytorch_lightning/trainer/training_loop.py | 5 +++-
 tests/trainer/test_trainer.py | 17 +++++++-------
 6 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e97450fdbd885..bdf415bd78a35 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -322,7 +322,9 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)
+        self.precision_plugin.clip_gradients(
+            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
 
     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 087f6df7a1c6a..1e00632810746 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _XLA_AVAILABLE:
@@ -56,7 +56,26 @@ def run_optimizer_step(
     ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+ """ + # todo: Add support for backward with all_gather + if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: + return xm.all_gather(tensor).view(-1, *tensor.shape) + return tensor + + def clip_gradients( + self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM + ) -> None: + assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, "Only NORM gradient clipping is supported on TPU for now" model = self.lightning_module parameters = model.parameters() diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index 899ffbf56e8fd..69efe1cc83c3f 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -39,7 +39,7 @@ def on_trainer_init( f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}" ) self.trainer.gradient_clip_val = gradient_clip_val - self.trainer.gradient_clip_algorithm = gradient_clip_algorithm + self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm) # gradient norm tracking if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf': diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e6bf36df92a01..38226734e4e57 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -633,11 +633,12 @@ def run_train(self) -> None: if not self.interrupted: self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() - except (RuntimeError, AssertionError): + except (RuntimeError, AssertionError) as e: # if an exception is raised, the finally block is executed and can hide the actual exception # that was initially raised if `on_train_end` also raises an exception. 
we want to avoid that # for assertions and other runtime errors so we aren't misled while debugging print_exc() + raise e finally: # hook self.train_loop.on_train_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d2749733812b3..0d2c4dca75fe5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -381,7 +381,10 @@ def track_and_norm_grad(self, optimizer): grad_norm_dic = self._track_gradient_norm() # clip gradients - self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val) + self.trainer.accelerator.clip_gradients( + optimizer, self.trainer.gradient_clip_val, + gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + ) self._cur_grad_norm_dict = grad_norm_dic def _track_gradient_norm(self): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 447ed2b41b8d6..ab22564689429 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir): model = BoringModel() - grad_clip_val = 0.0001 + grad_clip_val = 1e-5 trainer = Trainer( max_steps=10, max_epochs=1, gradient_clip_val=grad_clip_val, gradient_clip_algorithm='value', - default_root_dir=tmpdir, + default_root_dir=tmpdir ) trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward @@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde parameters = model.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - assert round(grad_max.item(), 6) <= grad_clip_val, \ - f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ." + assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \ + f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ." return ret_val @@ -996,7 +996,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): tutils.reset_seed() model = BoringModel() - grad_clip_val = 0.0001 + grad_clip_val = 1e-5 trainer = Trainer( max_steps=10, max_epochs=1, @@ -1016,9 +1016,10 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde # test that gradient is clipped correctly ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = model.parameters() - grad_max = torch.max(torch.stack([p.grad.detach() for p in parameters])) - assert round(grad_max.item(), 6) <= grad_clip_val, \ - f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ." + grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] + grad_max = torch.max(torch.stack(grad_max_list)) + assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \ + f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ." 
return ret_val From 17a198e94f5a9eae096f6226c8d61f232d681390 Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 00:21:03 +0800 Subject: [PATCH 2/8] Fix the PEP8 compliance issue --- pytorch_lightning/accelerators/tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 1e00632810746..bb12af2d05a70 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -75,7 +75,8 @@ def clip_gradients( self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM ) -> None: - assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, "Only NORM gradient clipping is supported on TPU for now" + assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \ + "Only NORM gradient clipping is supported on TPU for now" model = self.lightning_module parameters = model.parameters() From 338769a65c52442a2cfcd58a3ae8a5b28cb7146a Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 00:23:57 +0800 Subject: [PATCH 3/8] Fix the mistake from rebasing --- pytorch_lightning/accelerators/tpu.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index bb12af2d05a70..e51393a53122f 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -56,21 +56,6 @@ def run_optimizer_step( ) -> None: xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - # todo: Add support for backward with all_gather - if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: - return xm.all_gather(tensor).view(-1, *tensor.shape) - return tensor - def clip_gradients( self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM From bf1bd4607004b49d2e4f932e3a7c79c08cf08b59 Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 00:34:02 +0800 Subject: [PATCH 4/8] Revert changes regarding to #6807 --- pytorch_lightning/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 38226734e4e57..7ebb1a9899d51 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -638,7 +638,6 @@ def run_train(self) -> None: # that was initially raised if `on_train_end` also raises an exception. 
we want to avoid that # for assertions and other runtime errors so we aren't misled while debugging print_exc() - raise e finally: # hook self.train_loop.on_train_end() From 218ecec214361f49324877fb2312f576a9ad56ea Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 00:38:04 +0800 Subject: [PATCH 5/8] Remove an unused variable --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7ebb1a9899d51..e6bf36df92a01 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -633,7 +633,7 @@ def run_train(self) -> None: if not self.interrupted: self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() - except (RuntimeError, AssertionError) as e: + except (RuntimeError, AssertionError): # if an exception is raised, the finally block is executed and can hide the actual exception # that was initially raised if `on_train_end` also raises an exception. we want to avoid that # for assertions and other runtime errors so we aren't misled while debugging From a3007ad5a03a967e711b3eccf117cd926e6ac41b Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 00:57:10 +0800 Subject: [PATCH 6/8] Expect the TPU test case to fail --- tests/models/test_tpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 6409f2ef4bcbf..b9c0111d313a5 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -222,7 +222,7 @@ def test_tpu_grad_norm(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_tpu_clip_grad_by_value(tmpdir): - """Test if clip_gradients by value works on TPU.""" + """Test if clip_gradients by value works on TPU. (It should not.)""" tutils.reset_seed() trainer_options = dict( default_root_dir=tmpdir, @@ -236,7 +236,8 @@ def test_tpu_clip_grad_by_value(tmpdir): ) model = BoringModel() - tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + with pytest.raises(AssertionError): + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) @RunIf(tpu=True) From 81cab3b61130542cf7063f72062f562072f9231e Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 22:51:06 +0800 Subject: [PATCH 7/8] Make the possibility of false positive even lower --- tests/trainer/test_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ab22564689429..6fb06760660db 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -918,7 +918,7 @@ def test_gradient_clipping_by_value(tmpdir): model = BoringModel() - grad_clip_val = 1e-5 + grad_clip_val = 1e-10 trainer = Trainer( max_steps=10, max_epochs=1, @@ -938,7 +938,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde parameters = model.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \ + assert abs(grad_max.item() - grad_clip_val) < 1e-11, \ f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ." 
return ret_val @@ -996,7 +996,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): tutils.reset_seed() model = BoringModel() - grad_clip_val = 1e-5 + grad_clip_val = 1e-10 trainer = Trainer( max_steps=10, max_epochs=1, @@ -1018,7 +1018,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde parameters = model.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \ + assert abs(grad_max.item() - grad_clip_val) < 1e-11, \ f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ." return ret_val From 3d3f7cc99047514dc5ff934ca6993fe932cd454a Mon Sep 17 00:00:00 2001 From: Ceshine Lee Date: Sat, 10 Apr 2021 23:10:18 +0800 Subject: [PATCH 8/8] Only train one step in test cases --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6fb06760660db..71f4b9371335c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -920,7 +920,7 @@ def test_gradient_clipping_by_value(tmpdir): grad_clip_val = 1e-10 trainer = Trainer( - max_steps=10, + max_steps=1, max_epochs=1, gradient_clip_val=grad_clip_val, gradient_clip_algorithm='value', @@ -998,7 +998,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): model = BoringModel() grad_clip_val = 1e-10 trainer = Trainer( - max_steps=10, + max_steps=1, max_epochs=1, precision=16, gpus=1,
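
For context on what these patches wire through: the `gradient_clip_algorithm` option maps onto the two standard PyTorch clipping utilities, and the updated tests above check the 'value' variant via the maximum absolute gradient. Below is a minimal standalone sketch of the two modes; the toy linear model and the 0.1 threshold are illustrative assumptions, not taken from the patches.

import torch

# Toy setup (illustrative only): any module with populated gradients will do.
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

clip_val = 0.1

# gradient_clip_algorithm='norm' (the Trainer default) corresponds to
# torch.nn.utils.clip_grad_norm_: rescale gradients so that their global
# 2-norm is at most clip_val.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_val, norm_type=2.0)

# gradient_clip_algorithm='value' corresponds to torch.nn.utils.clip_grad_value_:
# clamp every gradient element into [-clip_val, clip_val], which is the property
# the tests above assert through the max of |grad|.
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_val)

assert max(p.grad.abs().max().item() for p in model.parameters()) <= clip_val

With the connector change in patch 1, a call such as Trainer(gradient_clip_val=1e-10, gradient_clip_algorithm='value') (as used in the updated tests) dispatches to the value-based path on CPU/GPU, while the TPU accelerator now explicitly asserts that only norm clipping is supported.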