Skip to content

Commit 2f6313b

Browse files
committed
add predict
1 parent e20ca20 commit 2f6313b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/trainer/test_dataloaders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
636636

637637
def test_warning_with_iterable_dataset_and_len(tmpdir):
638638
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
639-
model = EvalModelTemplate()
639+
model = BoringModel()
640640
original_dataset = model.train_dataloader().dataset
641641

642642
class IterableWithoutLen(IterableDataset):
@@ -660,6 +660,8 @@ def __len__(self):
660660
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
661661
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
662662
trainer.test(model, test_dataloaders=[dataloader])
663+
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
664+
trainer.predict(model, dataloaders=[dataloader])
663665

664666
# without __len__ defined
665667
dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
@@ -669,6 +671,7 @@ def __len__(self):
669671
trainer.validate(model, val_dataloaders=dataloader)
670672
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
671673
trainer.test(model, test_dataloaders=dataloader)
674+
trainer.predict(model, dataloaders=dataloader)
672675

673676

674677
@RunIf(min_gpus=2)

0 commit comments

Comments
 (0)