
The total number of batches shown by the progress bar of the sanity check is wrong #2891

@manipopopo

Description


🐛 Bug

The total of the sanity check progress bar is set by
https://github.com/PyTorchLightning/pytorch-lightning/blob/4d0406ec8bf1c9147b34eb607411b78a9cd28243/pytorch_lightning/callbacks/progress.py#L296

The progress bar will always show trainer.num_sanity_val_steps as the total, even if the length of the validation DataLoader is less than trainer.num_sanity_val_steps.
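The mechanism is easy to see with tqdm directly: the total is fixed when the bar is created, so if it is set too high the bar can never reach 100%. A minimal illustration (plain tqdm, independent of Lightning):

```python
from tqdm import tqdm

# Total is set to 999 up front, but only 10 batches actually run,
# so the bar stops at 10/999 (about 1%) instead of completing.
bar = tqdm(total=999, desc='Validation sanity check')
for _ in range(10):
    bar.update(1)
bar.close()
```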

Maybe the total could be computed by

from pytorch_lightning.callbacks.progress import convert_inf
from pytorch_lightning.trainer import data_loading

# Cap each validation dataloader's length at num_sanity_val_steps;
# dataloaders without __len__ are treated as infinite.
num_full_val_dataloader_batches = [
    len(dataloader) if data_loading._has_len(dataloader) else float('inf')
    for dataloader in trainer.val_dataloaders
]
self.val_progress_bar.total = convert_inf(
    sum(min(num_batches, trainer.num_sanity_val_steps)
        for num_batches in num_full_val_dataloader_batches))

This uses the private function data_loading._has_len to check whether a dataloader has __len__; maybe we could make data_loading._has_len public.

Or we could make num_full_val_dataloader_batches (and num_full_train_dataloader_batches) a member variable of Trainer and update the value in pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin.
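The proposed computation can also be expressed as a standalone sketch, with no Lightning dependencies. Note that `has_len` and `sanity_check_total` here are hypothetical helper names chosen for illustration, not the PyTorch Lightning API; `None` stands in for an infinite total, mirroring what convert_inf does for tqdm:

```python
import math

def has_len(dataloader):
    """Return True if the dataloader supports len() (hypothetical helper)."""
    try:
        len(dataloader)
        return True
    except TypeError:
        return False

def sanity_check_total(val_dataloaders, num_sanity_val_steps):
    """Sum each dataloader's length capped at num_sanity_val_steps.

    Dataloaders without __len__ count as infinite; an infinite total is
    returned as None (tqdm's convention for an unknown total).
    """
    lengths = [
        len(dl) if has_len(dl) else math.inf
        for dl in val_dataloaders
    ]
    total = sum(min(n, num_sanity_val_steps) for n in lengths)
    return None if math.isinf(total) else int(total)

# A 10-element "dataloader" with num_sanity_val_steps=999 caps at 10.
print(sanity_check_total([range(10)], 999))  # → 10
```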

To Reproduce

The progress bar of the sanity check in the following code (num_sanity_val_steps == 999 and len(val_data_loader) == 10) shows

Validation sanity check:   1%|          | 9/999 [00:09<16:31,  1.00s/it]

Code sample

import time

import pytorch_lightning as pl
from torch.utils import data


class Dataset(data.Dataset):

  def __init__(self, length):
    self._elements = list(range(length))

  def __getitem__(self, item):
    return self._elements[item]

  def __len__(self):
    return len(self._elements)


class Model(pl.LightningModule):

  def forward(self, *args, **kwargs):
    pass

  def training_step(self, *args, **kwargs):
    pass

  def train_dataloader(self):
    pass

  def configure_optimizers(self):
    pass

  def validation_step(self, *args, **kwargs):
    time.sleep(1)
    return pl.EvalResult()


if __name__ == '__main__':
  model = Model()

  val_dataset_length = 10
  val_dataset = Dataset(val_dataset_length)
  val_data_loader = data.DataLoader(val_dataset)

  trainer = pl.Trainer(num_sanity_val_steps=999, limit_val_batches=999,
                       max_epochs=0)
  trainer.fit(model, val_dataloaders=val_data_loader)

Expected behavior

The sanity check progress bar of the program above should show

Validation sanity check: 100%|██████████| 10/10 [00:10<00:00,  1.00s/it]

Environment

  • CUDA:
    • GPU:
    • available:
    • version:
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0+cpu
    • pytorch-lightning: 0.9.0rc11
    • tensorboard: 1.15.0
    • tqdm: 4.48.2
  • System:
    • OS: Windows
    • architecture:
      • 64bit
      • WindowsPE
    • processor:
    • python: 3.7.3
    • version: 10.0.18362

Additional context

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)
