@@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir):
 
     model = BoringModel()
 
-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-10
     trainer = Trainer(
-        max_steps=10,
+        max_steps=1,
         max_epochs=1,
         gradient_clip_val=grad_clip_val,
         gradient_clip_algorithm='value',
-        default_root_dir=tmpdir,
+        default_root_dir=tmpdir
     )
 
     trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
@@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
         parameters = model.parameters()
         grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
         grad_max = torch.max(torch.stack(grad_max_list))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val}."
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}."
 
         return ret_val
 
@@ -996,9 +996,9 @@ def test_gradient_clipping_by_value_fp16(tmpdir):
996996 tutils .reset_seed ()
997997
998998 model = BoringModel ()
999- grad_clip_val = 0.0001
999+ grad_clip_val = 1e-10
10001000 trainer = Trainer (
1001- max_steps = 10 ,
1001+ max_steps = 1 ,
10021002 max_epochs = 1 ,
10031003 precision = 16 ,
10041004 gpus = 1 ,
@@ -1016,9 +1016,10 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
         # test that gradient is clipped correctly
         ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
         parameters = model.parameters()
-        grad_max = torch.max(torch.stack([p.grad.detach() for p in parameters]))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val}."
+        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
+        grad_max = torch.max(torch.stack(grad_max_list))
+        assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}."
 
         return ret_val
 
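Note on the assertion change: with gradient_clip_algorithm='value', every gradient element is clamped to the range +/- grad_clip_val. A threshold as tiny as 1e-10 is all but guaranteed to be exceeded by the raw gradients of a freshly initialized model, so after clipping the maximum absolute gradient should land exactly on the threshold. Asserting abs(grad_max - grad_clip_val) < 1e-11 therefore confirms that clipping actually ran, whereas the old round(grad_max, 6) <= grad_clip_val check could pass even if no clipping happened. The sketch below illustrates the same reasoning outside the Trainer, using torch.nn.utils.clip_grad_value_ directly; the Linear stand-in model and seed are illustrative assumptions, not part of the PR.

# Minimal sketch (assumes plain torch; a Linear layer stands in for BoringModel).
# Shows why max |grad| equals grad_clip_val after value clipping with a tiny threshold.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)           # stand-in for BoringModel
loss = model(torch.randn(8, 4)).sum()
loss.backward()                         # raw gradients are far larger than 1e-10

grad_clip_val = 1e-10
torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip_val)

grad_max_list = [p.grad.detach().abs().max() for p in model.parameters()]
grad_max = torch.max(torch.stack(grad_max_list))

# Equality-style check: clipping clamps every element to +/- grad_clip_val,
# so the maximum must sit on the threshold (up to float32 rounding).
assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
    f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}."

Reducing max_steps from 10 to 1 presumably follows the same intent: a single training step is enough to trigger and verify the clipping path.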