|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import pytest |
| 15 | +import torch |
15 | 16 |
|
16 | | -from pytorch_lightning import Trainer |
| 17 | +from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
17 | 18 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
18 | | -from tests.helpers import BoringModel |
| 19 | +from tests.helpers import BoringModel, RandomDataset |
19 | 20 |
|
20 | 21 |
|
21 | 22 | def test_wrong_train_setting(tmpdir): |
@@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir): |
101 | 102 | model = BoringModel() |
102 | 103 | model.validation_step = None |
103 | 104 | trainer.validate(model) |
| 105 | + |
| 106 | + |
| 107 | +@pytest.mark.parametrize("datamodule", [False, True]) |
| 108 | +def test_trainer_predict_verify_config(tmpdir, datamodule): |
| 109 | + |
| 110 | + class TestModel(LightningModule): |
| 111 | + |
| 112 | + def __init__(self): |
| 113 | + super().__init__() |
| 114 | + self.layer = torch.nn.Linear(32, 2) |
| 115 | + |
| 116 | + def forward(self, x): |
| 117 | + return self.layer(x) |
| 118 | + |
| 119 | + class TestLightningDataModule(LightningDataModule): |
| 120 | + |
| 121 | + def __init__(self, dataloaders): |
| 122 | + super().__init__() |
| 123 | + self._dataloaders = dataloaders |
| 124 | + |
| 125 | + def test_dataloader(self): |
| 126 | + return self._dataloaders |
| 127 | + |
| 128 | + def predict_dataloader(self): |
| 129 | + return self._dataloaders |
| 130 | + |
| 131 | + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] |
| 132 | + |
| 133 | + model = TestModel() |
| 134 | + |
| 135 | + trainer = Trainer(default_root_dir=tmpdir) |
| 136 | + |
| 137 | + if datamodule: |
| 138 | + datamodule = TestLightningDataModule(dataloaders) |
| 139 | + results = trainer.predict(model, datamodule=datamodule) |
| 140 | + else: |
| 141 | + results = trainer.predict(model, dataloaders=dataloaders) |
| 142 | + |
| 143 | + assert len(results) == 2 |
| 144 | + assert results[0][0].shape == torch.Size([1, 2]) |
| 145 | + |
| 146 | + model.predict_dataloader = None |
| 147 | + |
| 148 | + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): |
| 149 | + trainer.predict(model) |
0 commit comments