2727from tests .helpers import BoringModel
2828
2929
class AMPTestModel(BoringModel):
    """BoringModel variant that checks fp16 autocast is active during training.

    Used by the AMP GPU tests: if the Trainer was built with ``precision=16``,
    every ``training_step`` call must run inside an autocast context and the
    forward pass must produce half-precision outputs.
    """

    def training_step(self, batch, batch_idx):
        # Autocast must already be enabled by the precision plugin here.
        assert torch.is_autocast_enabled()
        predictions = self(batch)
        # Under fp16 autocast the forward output is expected to be half precision.
        assert predictions.dtype == torch.float16
        loss = self.loss(batch, predictions)
        return {"loss": loss}
38+
39+
3040@pytest .mark .skip (reason = 'dp + amp not supported currently' ) # TODO
3141@pytest .mark .skipif (not torch .cuda .is_available (), reason = "test requires GPU machine" )
3242def test_amp_single_gpu_dp (tmpdir ):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
4151 precision = 16 ,
4252 )
4353
44- model = BoringModel ()
54+ model = AMPTestModel ()
4555 # tutils.run_model_test(trainer_options, model)
4656 trainer .fit (model )
4757
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
6070 precision = 16 ,
6171 )
6272
63- model = BoringModel ()
73+ model = AMPTestModel ()
6474 # tutils.run_model_test(trainer_options, model)
6575 trainer .fit (model )
66-
6776 assert trainer .state == TrainerState .FINISHED , f"Training failed with { trainer .state } "
6877
6978
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
8190 precision = 16 ,
8291 )
8392
84- model = BoringModel ()
93+ model = AMPTestModel ()
8594 # tutils.run_model_test(trainer_options, model)
8695 trainer .fit (model )
8796
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
100109 precision = 16 ,
101110 )
102111
103- model = BoringModel ()
112+ model = AMPTestModel ()
104113 # tutils.run_model_test(trainer_options, model)
105114 trainer .fit (model )
106-
107115 assert trainer .state == TrainerState .FINISHED , f"Training failed with { trainer .state } "
108116
109117
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
122130 # simulate setting slurm flags
123131 tutils .set_random_master_port ()
124132
125- model = BoringModel ()
133+ model = AMPTestModel ()
126134
127135 # exp file to get meta
128136 logger = tutils .get_default_logger (tmpdir )
0 commit comments