🐛 Bug
The total of the sanity check progress bar is set by
https://github.com/PyTorchLightning/pytorch-lightning/blob/4d0406ec8bf1c9147b34eb607411b78a9cd28243/pytorch_lightning/callbacks/progress.py#L296
The progress bar will always show trainer.num_sanity_val_steps even if the length of the validation DataLoader is less than trainer.num_sanity_val_steps.
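For reference, the linked line sets the total roughly as follows (paraphrased from the 0.9.0rc source; the exact expression may differ slightly):

# approximate current behavior: the total ignores the actual dataloader lengths
self.val_progress_bar.total = convert_inf(
    trainer.num_sanity_val_steps * len(trainer.val_dataloaders))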
Maybe the total could be computed by:

from pytorch_lightning.trainer import data_loading

num_full_val_dataloader_batches = [
    len(dataloader) if data_loading._has_len(dataloader) else float('inf')
    for dataloader in trainer.val_dataloaders
]
self.val_progress_bar.total = convert_inf(
    sum(min(num_batches, trainer.num_sanity_val_steps)
        for num_batches in num_full_val_dataloader_batches))

This uses the private function data_loading._has_len to check whether a dataloader implements __len__; maybe data_loading._has_len could be made public.
Or we could make num_full_val_dataloader_batches (and num_full_train_dataloader_batches) a member variable of Trainer and update it in pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, as sketched below.
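A minimal sketch of that second option, assuming a hypothetical num_full_val_dataloader_batches attribute populated when the validation dataloaders are prepared (the attribute and helper function below are illustrative, not existing Trainer API):

from pytorch_lightning.trainer.data_loading import _has_len

def record_val_dataloader_lengths(trainer):
    # Hypothetical helper: remember the real per-dataloader batch counts
    # so the progress bar callback can cap its total against them.
    trainer.num_full_val_dataloader_batches = [
        len(dl) if _has_len(dl) else float('inf')
        for dl in trainer.val_dataloaders
    ]

The progress bar callback could then use sum(min(n, trainer.num_sanity_val_steps) for n in trainer.num_full_val_dataloader_batches) as the sanity check total, exactly as in the snippet above.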
To Reproduce
The progress bar of the sanity check in the following code (num_sanity_val_steps == 999 and len(val_data_loader) == 10) shows
Validation sanity check: 1%| | 9/999 [00:09<16:31, 1.00s/it]
Code sample
import time

import pytorch_lightning as pl
from torch.utils import data


class Dataset(data.Dataset):
    def __init__(self, length):
        self._elements = list(range(length))

    def __getitem__(self, item):
        return self._elements[item]

    def __len__(self):
        return len(self._elements)


class Model(pl.LightningModule):
    # Training-related hooks are stubbed out; only validation matters here.
    def forward(self, *args, **kwargs):
        pass

    def training_step(self, *args, **kwargs):
        pass

    def train_dataloader(self):
        pass

    def configure_optimizers(self):
        pass

    def validation_step(self, *args, **kwargs):
        # Sleep so the sanity check progress bar is visible for a while.
        time.sleep(1)
        return pl.EvalResult()


if __name__ == '__main__':
    model = Model()
    val_dataset_length = 10
    val_dataset = Dataset(val_dataset_length)
    val_data_loader = data.DataLoader(val_dataset)
    trainer = pl.Trainer(num_sanity_val_steps=999, limit_val_batches=999,
                         max_epochs=0)
    trainer.fit(model, val_dataloaders=val_data_loader)

Expected behavior
The progress bar of the program above should show
Validation sanity check: 100%|██████████| 10/10 [00:10<00:00, 1.00s/it]
Environment
- CUDA:
  - GPU:
  - available:
  - version:
- Packages:
  - numpy: 1.18.5
  - pyTorch_debug: False
  - pyTorch_version: 1.6.0+cpu
  - pytorch-lightning: 0.9.0rc11
  - tensorboard: 1.15.0
  - tqdm: 4.48.2
- System:
  - OS: Windows
  - architecture:
    - 64bit
    - WindowsPE
  - processor:
  - python: 3.7.3
  - version: 10.0.18362