IterableDataset has no len() #2866

@raman-r-4978

Description

🐛 Bug

When using an IterableDataset, training terminates before it starts

An example is given by PyTorch Lightning: https://www.kaggle.com/pytorchlightning/pytorch-on-tpu-with-pytorch-lightning

To Reproduce

Steps to reproduce the behavior:

  1. pip install the following

(Note: Install procedure is also given here)

  2. Keep the LightningModule unchanged
  3. Change the Data module as below
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class CustomeDataset(torch.utils.data.IterableDataset):
    """Wraps a map-style dataset so it is consumed as a pure iterable."""
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        for data in self.dataset:
            yield data

# train/val split
mnist_train = MNIST("", train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

mnist_test = MNIST("", train=False, download=True, transform=transforms.ToTensor())

mnist_train = DataLoader(CustomeDataset(mnist_train), batch_size=64, num_workers=4)
mnist_val = DataLoader(CustomeDataset(mnist_val), batch_size=64, num_workers=4)
mnist_test = DataLoader(CustomeDataset(mnist_test), batch_size=64, num_workers=4)
  4. Train
import pytorch_lightning as pl

image_classifier = ImageClassifier()  # the unchanged LightningModule
trainer = pl.Trainer(
    tpu_cores=8,
    max_epochs=10,
    auto_select_gpus=True,
    val_check_interval=100,
    limit_train_batches=100,
    weights_save_path="output",
)
trainer.fit(image_classifier, mnist_train, mnist_val)
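The failure mode in the title can be reproduced without torch: an object that only defines `__iter__`, as `CustomeDataset` does, supports iteration but not `len()`. The class names below are hypothetical, torch-free stand-ins for illustration; the second one shows the common workaround of delegating `__len__` to the wrapped dataset.

```python
class PlainIterable:
    """Stand-in for an IterableDataset: defines __iter__ but no __len__."""
    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)


class SizedIterable(PlainIterable):
    """Same stream, but also reports a length, as the wrapper above could
    do by delegating to len(self.dataset)."""
    def __len__(self):
        return len(self.items)


stream = PlainIterable(range(5))
try:
    len(stream)  # raises TypeError: the object has no len()
except TypeError as exc:
    print(exc)

print(len(SizedIterable(range(5))))  # prints 5
```

If `random_split` subsets are what is being wrapped, they are map-style and already sized, so a delegating `__len__` on the wrapper is straightforward.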

Expected behavior

Training should run without any errors.
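In principle, the batch caps in the Trainer config above (`limit_train_batches=100`, `val_check_interval=100`) do not require the loader to have a length: a bounded loop can be expressed with `itertools.islice`, which stops after N items without ever calling `len()`. A minimal sketch (the function and generator names are hypothetical, not Lightning internals):

```python
from itertools import islice


def run_capped_epoch(loader, limit_train_batches):
    """Consume at most `limit_train_batches` batches from an unsized
    iterable; islice never queries len(loader)."""
    processed = 0
    for _batch in islice(loader, limit_train_batches):
        processed += 1
    return processed


def endless_batches():
    """An infinite, length-less 'loader'."""
    i = 0
    while True:
        yield i
        i += 1


print(run_capped_epoch(endless_batches(), 100))  # prints 100
```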

Environment

  • Reproducible on XLA backend [CPU/TPU]: TPU
  • torch_xla version: 1.8

Additional context
