@@ -54,7 +54,7 @@ def val_dataloader(self):
         return DataLoader(RandomDataset(32, 2000), batch_size=32)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -73,7 +73,7 @@ def test_model_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -92,7 +92,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
     assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -111,7 +111,7 @@ def test_model_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -132,7 +132,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -153,7 +153,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -173,7 +173,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
@@ -200,7 +200,7 @@ def validation_step(self, *args, **kwargs):
     trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_grad_norm(tmpdir):
     """Test if grad_norm works on TPU."""
@@ -219,16 +219,24 @@ def test_tpu_grad_norm(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):
     """Test if dataloaders passed to trainer works on TPU"""
 
     tutils.reset_seed()
     model = BoringModel()
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
-    trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        tpu_cores=8,
+    )
+    trainer.fit(
+        model,
+        train_dataloader=model.train_dataloader(),
+        val_dataloaders=model.val_dataloader(),
+    )
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -237,7 +245,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
     [pytest.param(1, None), pytest.param(8, None),
      pytest.param([1], 1), pytest.param([8], 8)],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
+@RunIf(tpu=True)
 def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
     """Test if trainer.tpu_id is set as expected"""
     assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id
@@ -258,13 +266,13 @@ def test_exception_when_no_tpu_found(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_broadcast_on_tpu():
     """ Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -296,7 +304,7 @@ def test_broadcast(rank):
         pytest.param(10, None, True),
     ],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     if error_expected:
@@ -312,7 +320,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     [pytest.param('--tpu_cores=8', {'tpu_cores': 8}),
      pytest.param("--tpu_cores=1,", {'tpu_cores': '1,'})]
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_cores_with_argparse(cli_args, expected):
     """Test passing tpu_cores in command line"""
@@ -327,7 +335,7 @@ def test_tpu_cores_with_argparse(cli_args, expected):
     assert Trainer.from_argparse_args(args)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_reduce():
     """Test tpu spawn reduce operation"""
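Every hunk above makes the same mechanical substitution: the hand-written `@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")` guard is replaced by the test suite's shared `@RunIf(tpu=True)` helper, so the skip condition and its reason string live in one place. For readers unfamiliar with the helper, here is a minimal sketch of how such a marker factory can be built. It is illustrative only: the `_TPU_AVAILABLE` import path is an assumption, and Lightning's actual `RunIf` (in the test helpers) supports further conditions such as GPU counts and minimum torch versions.

```python
import pytest

# Assumption: the tests already import this availability flag; fall back to
# False so the sketch stays importable on machines without pytorch_lightning.
try:
    from pytorch_lightning.utilities import _TPU_AVAILABLE
except ImportError:
    _TPU_AVAILABLE = False


class RunIf:
    """Skip-marker factory: ``@RunIf(tpu=True)`` skips the test unless a TPU is present."""

    def __new__(cls, tpu: bool = False):
        conditions = []
        reasons = []
        if tpu:
            conditions.append(not _TPU_AVAILABLE)
            reasons.append("TPU")
        # Returning the mark lets RunIf be applied directly as a decorator;
        # the test is skipped when any required resource is missing.
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(reasons)}]",
        )
```

Usage mirrors the hunks above:

```python
@RunIf(tpu=True)
@pl_multi_process_test
def test_model_tpu_cores_1(tmpdir):
    ...
```

Besides removing duplication, the factory keeps skip reasons uniform across the suite and gives any new hardware requirement a single extension point.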