🐛 Bug
When using IterableDataset, training terminates before it starts
Example is given by Pytorch-Lightning: https://www.kaggle.com/pytorchlightning/pytorch-on-tpu-with-pytorch-lightning
To Reproduce
Steps to reproduce the behavior:
pip install the following:
- cloud-tpu-client==0.10
- https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
- pytorch-lightning
- torchvision

(Note: the install procedure is also given here)
- Keep the LightningModule unchanged
- Change the DataModule as below:
class CustomeDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        for data in self.dataset:
            yield data

# train/val split
mnist_train = MNIST("", train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
mnist_test = MNIST("", train=False, download=True, transform=transforms.ToTensor())

mnist_train = DataLoader(CustomeDataset(mnist_train), batch_size=64, num_workers=4)
mnist_val = DataLoader(CustomeDataset(mnist_val), batch_size=64, num_workers=4)
mnist_test = DataLoader(CustomeDataset(mnist_test), batch_size=64, num_workers=4)
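For context, here is a minimal torch-free sketch of why this pattern can misbehave under multiple workers/replicas: the CustomeDataset above does no sharding, so every process handed the iterable replays the full stream. (The names below are illustrative only, not Lightning's or torch's API.)

```python
# Hedged sketch (pure Python, no torch): an IterableDataset with no
# sharding logic is replayed in full by every worker/replica, so the
# total number of items consumed is a multiple of the dataset size
# rather than one clean epoch.
def worker_pass(make_stream):
    # Each worker independently iterates the whole stream.
    return list(make_stream())

make_stream = lambda: iter(range(6))                  # stand-in for CustomeDataset
seen = [worker_pass(make_stream) for _ in range(8)]   # 8 TPU cores
total_items = sum(len(s) for s in seen)
assert total_items == 8 * 6  # 8x duplication, not a 6-item epoch
```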
- Train
image_classifier = ImageClassifier()
trainer = pl.Trainer(
    tpu_cores=8,
    max_epochs=10,
    auto_select_gpus=True,
    val_check_interval=100,
    limit_train_batches=100,
    weights_save_path="output",
)
trainer.fit(image_classifier, mnist_train, mnist_val)
Expected behavior
Training should run without errors
Environment
- Reproducible on XLA backend [CPU/TPU]: TPU
- torch_xla version: 1.8
Additional context
- This example works fine on CPU and GPU
- You may already be aware of this issue; there has been some discussion in Distributed TPU Training, training data stored in GCS #2690
- I am also attaching the error trace for reference
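For what it's worth, the usual workaround is to shard the stream inside `__iter__` so each replica sees a disjoint slice. A minimal sketch follows; the `rank`/`world_size` plumbing is hypothetical — in practice these would come from `torch.utils.data.get_worker_info()` or the XLA replica ordinal.

```python
# Hedged sketch: round-robin sharding so the union of all replicas'
# items covers the dataset exactly once. `rank` / `world_size` are
# hypothetical stand-ins for the real worker/replica identifiers.
class ShardedIterable:
    def __init__(self, data, rank, world_size):
        self.data = data
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Keep every world_size-th item, offset by this replica's rank.
        for i, item in enumerate(self.data):
            if i % self.world_size == self.rank:
                yield item

data = list(range(10))
shards = [list(ShardedIterable(data, r, 4)) for r in range(4)]
merged = sorted(x for shard in shards for x in shard)
assert merged == data  # exactly one epoch across all replicas combined
```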