
Epochs terminating early incorrectly #12956

@jasonxzhou

Description

🐛 Bug

My understanding is that custom dataloaders are expected to reset themselves and raise StopIteration when __next__ is called and there is nothing more to yield (e.g. at the end of an epoch).
However, starting with pytorch-lightning == 1.6.0, Lightning appears to detect on its own that there are no more batches to fetch, so this final __next__ call at the end of the epoch is never made. As a result, on the next epoch, the first __next__ call triggers the deferred reset and StopIteration, and that epoch ends immediately.
This results in every other epoch being skipped/terminating early:

Time for epoch 0: 0.14461779594421387
Time for epoch 1: 0.000400543212890625   <- skipped 
Time for epoch 2: 0.11101579666137695
Time for epoch 3: 0.00034236907958984375 <- skipped 
Time for epoch 4: 0.11014986038208008
Time for epoch 5: 0.0003437995910644531  <- skipped 
Time for epoch 6: 0.1101064682006836
Time for epoch 7: 0.00034737586975097656 <- skipped 
Time for epoch 8: 0.1124570369720459
Time for epoch 9: 0.0006518363952636719  <- skipped 
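
For context, the reset-and-raise contract described above looks roughly like the minimal sketch below (illustrative only, not DALI code; the class name and fake batches are made up). The key point is that the final __next__ call of an epoch both resets the internal state and raises StopIteration, so skipping that call leaves the reset pending until the next epoch's first __next__.

class AutoResetIterable:
    """Minimal stand-in for a self-resetting iterator (illustrative only)."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._i >= self.num_batches:
            self._i = 0              # reset for the next epoch ...
            raise StopIteration      # ... and signal end of epoch
        self._i += 1
        return self._i               # stand-in for a real batch
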

To Reproduce

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule

import torch
import os
import time 

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = x[0]['random'].float()  # DALI iterator yields a list of per-GPU dicts; take the 'random' output
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
    
    def on_train_epoch_start(self):
        self.start_time = time.time()
    
    def on_train_epoch_end(self):
        print(f'Time for epoch {self.current_epoch}: {time.time() - self.start_time}')

    def setup(self, stage=None):
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        boring_pipeline = BoringPipeline(batch_size=2, device='gpu', device_id=device_id, shard_id=shard_id, num_shards=num_shards, num_threads=8)
        # 100 samples per epoch, 2 per batch; auto_reset re-initializes the iterator after it raises StopIteration
        self.train_loader = DALIGenericIterator(boring_pipeline, output_map=['random'], size=100, last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True)

    def train_dataloader(self):
        return self.train_loader

@pipeline_def
def BoringPipeline(device, shard_id, num_shards):
    # produces random 32-element vectors; the sharding arguments are accepted but unused here
    return fn.random.coin_flip(shape=32)

def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=0,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
    )
    trainer.fit(model)

if __name__ == "__main__":
    run()

Expected behavior

On pytorch-lightning==1.5.10, there is no such issue, and we get the following times:

Time for epoch 0: 0.12346410751342773
Time for epoch 1: 0.09711241722106934
Time for epoch 2: 0.09668135643005371
Time for epoch 3: 0.09659314155578613
Time for epoch 4: 0.09643316268920898
Time for epoch 5: 0.0983741283416748
Time for epoch 6: 0.09633684158325195
Time for epoch 7: 0.09726572036743164
Time for epoch 8: 0.09633612632751465
Time for epoch 9: 0.10026073455810547

Environment

* CUDA:
        - GPU:
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
                - NVIDIA A40
        - available:         True
        - version:           11.5
* Packages:
        - numpy:             1.22.3
        - pyTorch_debug:     False
        - pyTorch_version:   1.11.0+cu115
        - pytorch-lightning: 1.6.2
        - tqdm:              4.64.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.10
        - version:           #1 SMP Debian 5.15.15-2~bpo11+1 (2022-02-03)

Additional context

cc @justusschock @awaelchli @ninginthecloud @rohitgr7

Labels

bug (Something isn't working) · data handling (Generic data-related topic)
