|
25 | 25 |
|
26 | 26 | import tests.base.develop_pipelines as tpipes |
27 | 27 | import tests.base.develop_utils as tutils |
28 | | -from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything |
| 28 | +from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything |
29 | 29 | from pytorch_lightning.callbacks import ModelCheckpoint |
30 | 30 | from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST |
31 | 31 |
|
@@ -139,6 +139,8 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir): |
139 | 139 | @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") |
140 | 140 | def test_running_test_pretrained_model_distrib_dp(tmpdir): |
141 | 141 | """Verify `test()` on pretrained model.""" |
| 142 | + __import__("pdb").set_trace() |
| 143 | + |
142 | 144 | tutils.set_random_master_port() |
143 | 145 |
|
144 | 146 | model = EvalModelTemplate() |
@@ -183,7 +185,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): |
183 | 185 | dataloaders = [dataloaders] |
184 | 186 |
|
185 | 187 | for dataloader in dataloaders: |
186 | | - tpipes.run_prediction(dataloader, pretrained_model) |
| 188 | + tpipes.run_prediction(pretrained_model, dataloader) |
187 | 189 |
|
188 | 190 |
|
189 | 191 | @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") |
@@ -234,7 +236,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir): |
234 | 236 | dataloaders = [dataloaders] |
235 | 237 |
|
236 | 238 | for dataloader in dataloaders: |
237 | | - tpipes.run_prediction(dataloader, pretrained_model) |
| 239 | + tpipes.run_prediction(pretrained_model, dataloader) |
238 | 240 |
|
239 | 241 |
|
240 | 242 | def test_running_test_pretrained_model_cpu(tmpdir): |
@@ -376,7 +378,7 @@ def assert_good_acc(): |
376 | 378 | dp_model.eval() |
377 | 379 |
|
378 | 380 | dataloader = trainer.train_dataloader |
379 | | - tpipes.run_prediction(dataloader, dp_model, dp=True) |
| 381 | + tpipes.run_prediction(dp_model, dataloader, dp=True) |
380 | 382 |
|
381 | 383 | # new model |
382 | 384 | model = EvalModelTemplate(**hparams) |
|
0 commit comments