Labels: bug, help wanted
🐛 Bug
Please reproduce using the BoringModel
When only `limit_train_batches` is set on the `Trainer` and the training data is an `IterableDataset` with no `__len__`, the progress bar breaks: its total becomes `nan`, so https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/progress.py#L412 crashes.
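For context, the length really is unknowable here: a `DataLoader` wrapping an `IterableDataset` that has no `__len__` raises when asked for its length, so Lightning has to fall back to a non-finite batch count. A minimal sketch of that constraint (the `NoLen` class is made up for illustration):

```python
from torch.utils.data import DataLoader, IterableDataset


class NoLen(IterableDataset):
    def __iter__(self):
        yield from range(8)


try:
    len(DataLoader(NoLen()))
except TypeError as err:
    # an IterableDataset without __len__ cannot report a length
    print(err)
```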
To Reproduce
```python
import math

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import ProgressBar
from torch.utils.data import DataLoader, IterableDataset


def example():
    return {"x": torch.randn(20), "y": torch.ones(1).long()}


class RegularIterableDataset(IterableDataset):
    def __init__(self, num_examples):
        super().__init__()
        self.num_examples = num_examples

    def __iter__(self):
        for _ in range(self.num_examples):
            yield example()


class BasicModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(20, 10)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def training_step(self, batch, batch_idx):
        x = batch["x"]
        y = batch["y"]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y.view(-1))
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        return self.training_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


class ProgressBarCheckingCallback(ProgressBar):
    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        if self.main_progress_bar.total is not None:
            assert not math.isnan(self.main_progress_bar.total)

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        if self.main_progress_bar.total is not None:
            # <- this assertion breaks, which will cause on_train_batch_end to break
            assert not math.isnan(self.main_progress_bar.total)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # <- because of the nan bug (https://fburl.com/j1s8uvum) this will crash
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)


trainer = pl.Trainer(
    max_epochs=2,
    limit_train_batches=3,
    callbacks=[ProgressBarCheckingCallback(refresh_rate=1, process_position=0)],
)
model = BasicModel()
train_dl = DataLoader(dataset=RegularIterableDataset(num_examples=128), batch_size=8)
val_dl = DataLoader(dataset=RegularIterableDataset(num_examples=128), batch_size=8)
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
```
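Until this is fixed upstream, a possible workaround is to coerce a non-finite total to `None`, which tqdm treats as an unknown length. This is a minimal sketch, assuming the same `ProgressBar` callback API used above; the `_sanitize` helper is hypothetical:

```python
import math

from pytorch_lightning.callbacks.progress import ProgressBar


class SafeProgressBar(ProgressBar):
    """Progress bar that never hands tqdm a nan/inf total."""

    @staticmethod
    def _sanitize(total):
        # tqdm accepts total=None and then renders an unknown-length bar
        if total is not None and (math.isinf(total) or math.isnan(total)):
            return None
        return total

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        self.main_progress_bar.total = self._sanitize(self.main_progress_bar.total)
```

Passing `SafeProgressBar()` in `callbacks` instead of the checking callback above keeps training alive, at the cost of a bar without a total.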
Expected behavior
Training runs to completion: when the dataloader length is unknown, the progress bar should fall back to an unknown-length display instead of producing a nan total and crashing.
Environment
Note: Bugs with code are solved faster! Colab Notebooks should be made public!
- IDE: Please, use our python bug_report_model.py template.
- Colab Notebook: Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
You can get the script and run it with:
```bash
wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
```
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context