@@ -636,7 +636,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
636636
637637def test_warning_with_iterable_dataset_and_len (tmpdir ):
638638 """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
639- model = EvalModelTemplate ()
639+ model = BoringModel ()
640640 original_dataset = model .train_dataloader ().dataset
641641
642642 class IterableWithoutLen (IterableDataset ):
@@ -660,6 +660,8 @@ def __len__(self):
660660 trainer .fit (model , train_dataloader = dataloader , val_dataloaders = [dataloader ])
661661 with pytest .warns (UserWarning , match = 'Your `IterableDataset` has `__len__` defined.' ):
662662 trainer .test (model , test_dataloaders = [dataloader ])
663+ with pytest .warns (UserWarning , match = 'Your `IterableDataset` has `__len__` defined.' ):
664+ trainer .predict (model , dataloaders = [dataloader ])
663665
664666 # without __len__ defined
665667 dataloader = DataLoader (IterableWithoutLen (), batch_size = 16 )
@@ -669,6 +671,7 @@ def __len__(self):
669671 trainer .validate (model , val_dataloaders = dataloader )
670672 trainer .fit (model , train_dataloader = dataloader , val_dataloaders = [dataloader ])
671673 trainer .test (model , test_dataloaders = dataloader )
674+ trainer .predict (model , dataloaders = dataloader )
672675
673676
674677@RunIf (min_gpus = 2 )
0 commit comments