Skip to content

Commit e1241dd

Browse files
committed
tests
1 parent 02a6144 commit e1241dd

File tree

5 files changed

+20
-11
lines changed

5 files changed

+20
-11
lines changed

tests/accelerators/legacy/test_ddp_spawn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from pytorch_lightning.trainer import Trainer
2222
from pytorch_lightning.trainer.states import TrainerState
2323
from tests.base import EvalModelTemplate
24+
from tests.helpers import BoringModel
25+
from tests.helpers.datamodules import ClassifDataModule
26+
from tests.helpers.simple_models import ClassificationModel
2427

2528

2629
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -38,8 +41,9 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
3841
accelerator='ddp_spawn',
3942
)
4043

41-
model = EvalModelTemplate()
42-
tpipes.run_model_test(trainer_options, model)
44+
dm = ClassifDataModule()
45+
model = ClassificationModel()
46+
tpipes.run_model_test(trainer_options, model, dm)
4347

4448

4549
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -56,7 +60,7 @@ def test_multi_gpu_model_ddp_spawn(tmpdir):
5660
progress_bar_refresh_rate=0,
5761
)
5862

59-
model = EvalModelTemplate()
63+
model = BoringModel()
6064

6165
tpipes.run_model_test(trainer_options, model)
6266

tests/accelerators/legacy/test_dp.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from pytorch_lightning.callbacks import EarlyStopping
2121
from pytorch_lightning.core import memory
2222
from tests.base import EvalModelTemplate
23+
from tests.helpers import BoringModel
24+
from tests.helpers.datamodules import ClassifDataModule
25+
from tests.helpers.simple_models import ClassificationModel
2326

2427
PRETEND_N_OF_GPUS = 16
2528

@@ -39,8 +42,9 @@ def test_multi_gpu_early_stop_dp(tmpdir):
3942
accelerator='dp',
4043
)
4144

42-
model = EvalModelTemplate()
43-
tpipes.run_model_test(trainer_options, model)
45+
dm = ClassifDataModule()
46+
model = ClassificationModel()
47+
tpipes.run_model_test(trainer_options, model, dm)
4448

4549

4650
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -57,7 +61,7 @@ def test_multi_gpu_model_dp(tmpdir):
5761
progress_bar_refresh_rate=0,
5862
)
5963

60-
model = EvalModelTemplate()
64+
model = BoringModel()
6165

6266
tpipes.run_model_test(trainer_options, model)
6367

tests/helpers/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def run_model_test_without_loggers(
5454

5555
def run_model_test(
5656
trainer_options,
57-
model,
57+
model: LightningModule,
5858
data: LightningDataModule = None,
5959
on_gpu: bool = True,
6060
version=None,

tests/models/test_gpu.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pytorch_lightning.utilities import device_parser
2626
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2727
from tests.helpers import BoringModel
28+
from tests.helpers.datamodules import ClassifDataModule
2829
from tests.helpers.simple_models import ClassificationModel
2930

3031
PRETEND_N_OF_GPUS = 16
@@ -43,8 +44,9 @@ def test_multi_gpu_none_backend(tmpdir):
4344
gpus=2,
4445
)
4546

47+
dm = ClassifDataModule()
4648
model = ClassificationModel()
47-
tpipes.run_model_test(trainer_options, model)
49+
tpipes.run_model_test(trainer_options, model, dm)
4850

4951

5052
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")

tests/models/test_tpu.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,11 @@ def test_tpu_grad_norm(tmpdir):
219219
@pl_multi_process_test
220220
def test_dataloaders_passed_to_fit(tmpdir):
221221
"""Test if dataloaders passed to trainer works on TPU"""
222-
223222
tutils.reset_seed()
224223
model = BoringModel()
225224

226-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
227-
trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
225+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8,)
226+
trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader(),)
228227
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
229228

230229

0 commit comments

Comments
 (0)