Skip to content

Commit d5b9706

Browse files
committed
Fix for dispatcher
1 parent 53cdd58 commit d5b9706

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/models/test_restore.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import tests.base.develop_pipelines as tpipes
2727
import tests.base.develop_utils as tutils
28-
from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything
28+
from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
2929
from pytorch_lightning.callbacks import ModelCheckpoint
3030
from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST
3131

@@ -139,6 +139,8 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir):
139139
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
140140
def test_running_test_pretrained_model_distrib_dp(tmpdir):
141141
"""Verify `test()` on pretrained model."""
142+
__import__("pdb").set_trace()
143+
142144
tutils.set_random_master_port()
143145

144146
model = EvalModelTemplate()
@@ -183,7 +185,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
183185
dataloaders = [dataloaders]
184186

185187
for dataloader in dataloaders:
186-
tpipes.run_prediction(dataloader, pretrained_model)
188+
tpipes.run_prediction(pretrained_model, dataloader)
187189

188190

189191
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -234,7 +236,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
234236
dataloaders = [dataloaders]
235237

236238
for dataloader in dataloaders:
237-
tpipes.run_prediction(dataloader, pretrained_model)
239+
tpipes.run_prediction(pretrained_model, dataloader)
238240

239241

240242
def test_running_test_pretrained_model_cpu(tmpdir):
@@ -376,7 +378,7 @@ def assert_good_acc():
376378
dp_model.eval()
377379

378380
dataloader = trainer.train_dataloader
379-
tpipes.run_prediction(dataloader, dp_model, dp=True)
381+
tpipes.run_prediction(dp_model, dataloader, dp=True)
380382

381383
# new model
382384
model = EvalModelTemplate(**hparams)

0 commit comments

Comments
 (0)