
Commit 24d0295

Fix the gradient_clip_algorithm has no effect issue. (#6928)
1 parent fb02972 commit 24d0295

6 files changed: +29 -17 lines changed


pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 1 deletion

@@ -325,7 +325,9 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)
+        self.precision_plugin.clip_gradients(
+            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )

     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch

pytorch_lightning/accelerators/tpu.py

Lines changed: 7 additions & 2 deletions

@@ -19,7 +19,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _XLA_AVAILABLE:

@@ -55,7 +55,12 @@ def run_optimizer_step(
     ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+    def clip_gradients(
+        self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0,
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
+    ) -> None:
+        assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
+            "Only NORM gradient clipping is supported on TPU for now"

         model = self.lightning_module
         parameters = model.parameters()
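After the assert, the method keeps doing what it did before: manual norm clipping over model.parameters(). The sketch below approximates that logic with plain CPU torch rather than XLA; it is not the code in this file, just an illustration of the technique.

# A minimal sketch, assuming plain torch (no XLA), of manual gradient-norm
# clipping over a parameter list, similar in spirit to the method patched above.
from typing import Iterable, Union

import torch


def clip_grad_norm_manual(
    parameters: Iterable[torch.nn.Parameter],
    clip_val: Union[int, float],
    norm_type: float = 2.0,
    eps: float = 1e-6,
) -> torch.Tensor:
    grads = [p.grad for p in parameters if p.grad is not None]
    # combined norm across all parameter gradients
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type
    )
    # shrink every gradient in place if the combined norm exceeds the threshold
    clip_coef = clip_val / (total_norm + eps)
    if clip_coef < 1:
        for g in grads:
            g.detach().mul_(clip_coef)
    return total_norm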

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def on_trainer_init(
                 f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
             )
         self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm
+        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)

         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
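Why this one-line change matters: the accelerators now compare the stored value against enum members (see the `is GradClipAlgorithmType.NORM` assert in tpu.py above), and a raw user-supplied string never passes an identity check. A small illustration with a stand-in enum, not the real GradClipAlgorithmType:

# Stand-in enum to show the effect of coercing the user string into the enum.
from enum import Enum


class ClipAlgo(str, Enum):
    VALUE = "value"
    NORM = "norm"


raw = "norm"             # what a user passes to Trainer(...)
coerced = ClipAlgo(raw)  # what the connector now stores

print(raw is ClipAlgo.NORM)      # False - a plain str is never the enum member
print(coerced is ClipAlgo.NORM)  # True  - identity checks now behave as expected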

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 1 deletion

@@ -437,7 +437,10 @@ def track_and_norm_grad(self, optimizer):
         grad_norm_dic = self._track_gradient_norm()

         # clip gradients
-        self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
+        self.trainer.accelerator.clip_gradients(
+            optimizer, self.trainer.gradient_clip_val,
+            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
+        )
         self._cur_grad_norm_dict = grad_norm_dic

     def _track_gradient_norm(self):
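User-facing effect of the fix, sketched below: the algorithm selected on the Trainer is now forwarded to the accelerator instead of being silently ignored. This assumes a pytorch_lightning version that includes this change.

# Usage sketch (assumes a Lightning version that includes this fix).
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=1,
    gradient_clip_val=0.5,            # clipping threshold
    gradient_clip_algorithm="value",  # clamp each gradient element to [-0.5, 0.5]
)
# trainer.fit(model)  # model: any LightningModule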

tests/models/test_tpu.py

Lines changed: 3 additions & 2 deletions

@@ -224,7 +224,7 @@ def test_tpu_grad_norm(tmpdir):
 @RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_clip_grad_by_value(tmpdir):
-    """Test if clip_gradients by value works on TPU."""
+    """Test if clip_gradients by value works on TPU. (It should not.)"""
     tutils.reset_seed()
     trainer_options = dict(
         default_root_dir=tmpdir,

@@ -238,7 +238,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
     )

     model = BoringModel()
-    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+    with pytest.raises(AssertionError):
+        tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


 @RunIf(tpu=True)
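The updated test simply asserts that the AssertionError added in tpu.py surfaces when clip-by-value is requested on TPU. The same guard-and-test pattern, in a self-contained form that needs no TPU (all names below are illustrative, not Lightning API):

# Self-contained sketch of the guard-and-test pattern used above.
from enum import Enum

import pytest


class ClipAlgo(str, Enum):
    VALUE = "value"
    NORM = "norm"


def clip_gradients_tpu_style(gradient_clip_algorithm: ClipAlgo) -> None:
    assert gradient_clip_algorithm is ClipAlgo.NORM, \
        "Only NORM gradient clipping is supported on TPU for now"
    # ...norm clipping would happen here...


def test_value_clipping_rejected():
    with pytest.raises(AssertionError):
        clip_gradients_tpu_style(ClipAlgo.VALUE)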

tests/trainer/test_trainer.py

Lines changed: 11 additions & 10 deletions

@@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir):

     model = BoringModel()

-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-10
     trainer = Trainer(
-        max_steps=10,
+        max_steps=1,
         max_epochs=1,
         gradient_clip_val=grad_clip_val,
         gradient_clip_algorithm='value',
-        default_root_dir=tmpdir,
+        default_root_dir=tmpdir
     )

     trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward

@@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
         parameters = model.parameters()
         grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
         grad_max = torch.max(torch.stack(grad_max_list))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."

         return ret_val

@@ -996,9 +996,9 @@ def test_gradient_clipping_by_value_fp16(tmpdir):
     tutils.reset_seed()

     model = BoringModel()
-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-10
     trainer = Trainer(
-        max_steps=10,
+        max_steps=1,
         max_epochs=1,
         precision=16,
         gpus=1,

@@ -1016,9 +1016,10 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
         # test that gradient is clipped correctly
         ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
         parameters = model.parameters()
-        grad_max = torch.max(torch.stack([p.grad.detach() for p in parameters]))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
+        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
+        grad_max = torch.max(torch.stack(grad_max_list))
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."

         return ret_val
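The reasoning behind the tightened assertion, reproduced as a standalone check with plain torch (not part of the commit): with a clip value as small as 1e-10, essentially every gradient exceeds the threshold, so clip-by-value drives the maximum absolute gradient to exactly the clip value, and near-equality is evidence the clipping actually ran.

# Standalone illustration of the new assertion's logic.
import torch

grad_clip_val = 1e-10
p = torch.nn.Parameter(torch.randn(8))
p.grad = torch.randn(8)  # gradients far larger than the clip value

torch.nn.utils.clip_grad_value_([p], clip_value=grad_clip_val)

grad_max = p.grad.detach().abs().max()
assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
    f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}"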
