Skip to content

Commit 3094220

Browse files
committed
tests
1 parent 4dbee0a commit 3094220

File tree

5 files changed

+20
-11
lines changed

5 files changed

+20
-11
lines changed

tests/accelerators/test_ddp_spawn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from pytorch_lightning.trainer import Trainer
2222
from pytorch_lightning.trainer.states import TrainerState
2323
from tests.base import EvalModelTemplate
24+
from tests.helpers import BoringModel
25+
from tests.helpers.datamodules import ClassifDataModule
26+
from tests.helpers.simple_models import ClassificationModel
2427

2528

2629
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -37,8 +40,9 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
3740
accelerator='ddp_spawn',
3841
)
3942

40-
model = EvalModelTemplate()
41-
tpipes.run_model_test(trainer_options, model)
43+
dm = ClassifDataModule()
44+
model = ClassificationModel()
45+
tpipes.run_model_test(trainer_options, model, dm)
4246

4347

4448
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -55,7 +59,7 @@ def test_multi_gpu_model_ddp_spawn(tmpdir):
5559
progress_bar_refresh_rate=0,
5660
)
5761

58-
model = EvalModelTemplate()
62+
model = BoringModel()
5963

6064
tpipes.run_model_test(trainer_options, model)
6165

tests/accelerators/test_dp.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from pytorch_lightning.callbacks import EarlyStopping
2121
from pytorch_lightning.core import memory
2222
from tests.base import EvalModelTemplate
23+
from tests.helpers import BoringModel
24+
from tests.helpers.datamodules import ClassifDataModule
25+
from tests.helpers.simple_models import ClassificationModel
2326

2427
PRETEND_N_OF_GPUS = 16
2528

@@ -39,8 +42,9 @@ def test_multi_gpu_early_stop_dp(tmpdir):
3942
accelerator='dp',
4043
)
4144

42-
model = EvalModelTemplate()
43-
tpipes.run_model_test(trainer_options, model)
45+
dm = ClassifDataModule()
46+
model = ClassificationModel()
47+
tpipes.run_model_test(trainer_options, model, dm)
4448

4549

4650
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -57,7 +61,7 @@ def test_multi_gpu_model_dp(tmpdir):
5761
progress_bar_refresh_rate=0,
5862
)
5963

60-
model = EvalModelTemplate()
64+
model = BoringModel()
6165

6266
tpipes.run_model_test(trainer_options, model)
6367

tests/helpers/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def run_model_test_without_loggers(
4949

5050
def run_model_test(
5151
trainer_options,
52-
model,
52+
model: LightningModule,
5353
data: LightningDataModule = None,
5454
on_gpu: bool = True,
5555
version=None,

tests/models/test_gpu.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pytorch_lightning.utilities import device_parser
2525
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2626
from tests.helpers import BoringModel
27+
from tests.helpers.datamodules import ClassifDataModule
2728
from tests.helpers.simple_models import ClassificationModel
2829

2930
PRETEND_N_OF_GPUS = 16
@@ -42,8 +43,9 @@ def test_multi_gpu_none_backend(tmpdir):
4243
gpus=2,
4344
)
4445

46+
dm = ClassifDataModule()
4547
model = ClassificationModel()
46-
tpipes.run_model_test(trainer_options, model)
48+
tpipes.run_model_test(trainer_options, model, dm)
4749

4850

4951
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")

tests/models/test_tpu.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,11 @@ def test_tpu_grad_norm(tmpdir):
222222
@pl_multi_process_test
223223
def test_dataloaders_passed_to_fit(tmpdir):
224224
"""Test if dataloaders passed to trainer works on TPU"""
225-
226225
tutils.reset_seed()
227226
model = BoringModel()
228227

229-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
230-
trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
228+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8,)
229+
trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader(),)
231230
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
232231

233232

0 commit comments

Comments
 (0)