 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
+from tests.helpers import BoringModel
 from tests.helpers.datasets import TrialMNIST
 from tests.helpers.utils import pl_multi_process_test

@@ -55,7 +56,7 @@ def test_model_tpu_cores_1(tmpdir):
         limit_val_batches=0.4,
     )

-    model = EvalModelTemplate()
+    model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@@ -73,7 +74,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
         limit_val_batches=0.4,
     )

-    model = EvalModelTemplate()
+    model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
     assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'

@@ -113,7 +114,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
         limit_val_batches=0.4,
     )

-    model = EvalModelTemplate()
+    model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"

@@ -133,7 +134,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
         limit_val_batches=0.2,
     )

-    model = EvalModelTemplate()
+    model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
@@ -192,19 +193,25 @@ def test_tpu_grad_norm(tmpdir):
         gradient_clip_val=0.1,
     )

-    model = EvalModelTemplate()
+    model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):
     """Test if dataloaders passed to trainer works on TPU"""
-
-    model = EvalModelTemplate()
-
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
-    trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        tpu_cores=8,
+    )
+    trainer.fit(
+        model,
+        train_dataloader=model.train_dataloader(),
+        val_dataloaders=model.val_dataloader(),
+    )
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

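For readers unfamiliar with the helper being swapped in: below is a hypothetical, minimal sketch of what a BoringModel-style LightningModule looks like, paired with a random-tensor dataset, to illustrate the kind of bare-bones model these TPU smoke tests rely on. The actual `BoringModel` in `tests/helpers` may differ in detail, and `RandomDataset` here is part of the sketch rather than a confirmed import.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule


class RandomDataset(Dataset):
    """Fixed-size dataset of random tensors (sketch only)."""

    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class BoringModel(LightningModule):
    """Bare-bones model: a single linear layer, no real learning task (sketch only)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Any differentiable scalar works as a "loss" for a smoke test.
        return self(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)
```

Because the model defines its own `train_dataloader`/`val_dataloader`, the `test_dataloaders_passed_to_fit` change above can pass those loaders explicitly to `trainer.fit(...)` without any dataset setup in the test itself.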