Labels: bug (Something isn't working), data handling (Generic data-related topic)
🐛 Bug
My understanding is that custom dataloaders are expected to reset themselves and raise StopIteration when __next__ is called and there is nothing left to yield (e.g. at the end of an epoch).
However, starting from pytorch-lightning == 1.6.0, Lightning appears to recognize that there are no more batches to yield, so this final __next__ call at the end of the epoch is no longer made. As a result, on the next epoch, the first __next__ call triggers the reset and StopIteration, and that epoch ends immediately.
This results in every other epoch being skipped/terminating early:
Time for epoch 0: 0.14461779594421387
Time for epoch 1: 0.000400543212890625 <- skipped
Time for epoch 2: 0.11101579666137695
Time for epoch 3: 0.00034236907958984375 <- skipped
Time for epoch 4: 0.11014986038208008
Time for epoch 5: 0.0003437995910644531 <- skipped
Time for epoch 6: 0.1101064682006836
Time for epoch 7: 0.00034737586975097656 <- skipped
Time for epoch 8: 0.1124570369720459
Time for epoch 9: 0.0006518363952636719 <- skipped
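For context, a minimal sketch of the iterator contract described above (illustrative only; the class name and batch contents are made up and this is not DALI's actual implementation, but it follows the same reset-and-raise pattern as DALIGenericIterator with auto_reset=True):

import torch

# Illustrative sketch, not DALI code: a loader that resets itself and raises
# StopIteration on the __next__ call that finds nothing left to yield.
class SelfResettingLoader:
    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._consumed = 0

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        return self

    def __next__(self):
        if self._consumed >= self.num_batches:
            self._consumed = 0       # auto-reset so the next epoch starts from the beginning
            raise StopIteration      # signal the end of the current epoch
        self._consumed += 1
        return [{"random": torch.randint(0, 2, (2, 32))}]

If the consumer stops after len(loader) batches and never makes that extra __next__ call, the reset above does not run, so the first __next__ of the following epoch resets and raises immediately, which matches the every-other-epoch pattern shown above.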
To Reproduce
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
import torch
import os
import time


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = x[0]['random'].float()  # unpack random data
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_start(self):
        self.start_time = time.time()

    def on_train_epoch_end(self):
        print(f'Time for epoch {self.current_epoch}: {time.time() - self.start_time}')

    def setup(self, stage=None):
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        mnist_pipeline = BoringPipeline(batch_size=2, device='gpu', device_id=device_id, shard_id=shard_id, num_shards=num_shards, num_threads=8)
        self.train_loader = DALIGenericIterator(mnist_pipeline, output_map=['random'], size=100, last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True)

    def train_dataloader(self):
        return self.train_loader


@pipeline_def
def BoringPipeline(device, shard_id, num_shards):
    return fn.random.coin_flip(shape=32)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=0,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()
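The consumption-pattern difference can also be shown without Lightning or DALI. The snippet below is a standalone sketch (it repeats a tiny self-resetting loader so it runs on its own) and only illustrates the two fetching strategies; it is not the trainer's actual code:

# Standalone sketch, not Lightning internals: "iterate until StopIteration" triggers
# the iterator's self-reset every epoch, while "stop after len(loader) batches"
# skips the resetting __next__ call and starves the following epoch.
class TinyLoader:
    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._consumed = 0

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        return self

    def __next__(self):
        if self._consumed >= self.num_batches:
            self._consumed = 0  # auto-reset for the next epoch
            raise StopIteration
        self._consumed += 1
        return self._consumed


loader = TinyLoader(num_batches=3)
for epoch in range(2):
    # Iterating until StopIteration resets the loader, so every epoch sees 3 batches.
    print(f"epoch {epoch}: {sum(1 for _ in loader)} batches")  # 3, then 3

loader = TinyLoader(num_batches=3)
for _ in range(len(loader)):
    next(loader)  # stop after len(loader) batches; the resetting call is never made
# The next epoch's very first __next__ now resets and raises, so it yields 0 batches.
print(f"next epoch: {sum(1 for _ in loader)} batches")  # 0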
Expected behavior
On pytorch-lightning==1.5.10, there is no such issue, and we get the following times:
Time for epoch 0: 0.12346410751342773
Time for epoch 1: 0.09711241722106934
Time for epoch 2: 0.09668135643005371
Time for epoch 3: 0.09659314155578613
Time for epoch 4: 0.09643316268920898
Time for epoch 5: 0.0983741283416748
Time for epoch 6: 0.09633684158325195
Time for epoch 7: 0.09726572036743164
Time for epoch 8: 0.09633612632751465
Time for epoch 9: 0.10026073455810547
Environment
* CUDA:
- GPU:
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- available: True
- version: 11.5
* Packages:
- numpy: 1.22.3
- pyTorch_debug: False
- pyTorch_version: 1.11.0+cu115
- pytorch-lightning: 1.6.2
- tqdm: 4.64.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- version: #1 SMP Debian 5.15.15-2~bpo11+1 (2022-02-03)
Additional context
cc @awaelchli