Commit ae9d7e0

Merge branch 'master' into attach-data-refactor
2 parents: 841381c + b9b3fa3

File tree: 6 files changed (+95, -16 lines)


CHANGELOG.md (2 additions & 0 deletions)

@@ -415,6 +415,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))
 
 
+- Fixed handling an `IterableDataset` that fails to produce a batch at the beginning of an epoch ([#7294](https://github.com/PyTorchLightning/pytorch-lightning/pull/7294))
+
 
 - Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))
 

pytorch_lightning/trainer/connectors/data_connector.py (2 additions & 13 deletions)

@@ -17,6 +17,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
@@ -43,22 +44,10 @@ def on_trainer_init(
 
     def get_profiled_train_dataloader(self, train_dataloader):
         profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
+            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
         )
         return profiled_dl
 
-    def _with_is_last(self, iterable):
-        """Pass through values from the given iterable with an added boolean indicating if this is the last item.
-        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
-        it = iter(iterable)
-        last = next(it)
-        for val in it:
-            # yield last and has next
-            yield last, False
-            last = val
-        # yield last, no longer has next
-        yield last, True
-
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
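For context, a quick sketch (not part of the diff) of what the refactored call consumes: `enumerate(prefetch_iterator(...))` yields the `(batch_idx, (batch, is_last_batch))` tuples that the training loop unpacks. Plain strings stand in for real batches, and the import assumes a Lightning build that contains this commit.

    from pytorch_lightning.trainer.supporters import prefetch_iterator

    # plain strings stand in for real batches
    for batch_idx, (batch, is_last_batch) in enumerate(prefetch_iterator(["b0", "b1", "b2"])):
        print(batch_idx, batch, is_last_batch)
    # prints:
    # 0 b0 False
    # 1 b1 False
    # 2 b2 True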

pytorch_lightning/trainer/supporters.py (23 additions & 1 deletion)

@@ -14,7 +14,7 @@
 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Generator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -508,3 +508,25 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
         new_data.append(x)
 
     return compute_func(new_data)
+
+
+def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
+    """
+    Returns an iterator that pre-fetches and caches the next item.
+    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
+    """
+    it = iter(iterable)
+
+    try:
+        # the iterator may be empty from the beginning
+        last = next(it)
+    except StopIteration:
+        return
+
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
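The behavioral change is the guarded `next(it)`: the old `_with_is_last` fetched the first item unguarded, so an iterable that produced no items raised an exception out of the generator, while `prefetch_iterator` simply yields nothing. A minimal check, again assuming a build with this commit:

    from pytorch_lightning.trainer.supporters import prefetch_iterator

    # an empty iterable yields nothing instead of raising
    assert list(prefetch_iterator([])) == []

    # a non-empty iterable passes through with an is-last flag on each item
    assert list(prefetch_iterator("abc")) == [("a", False), ("b", False), ("c", True)]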

pytorch_lightning/trainer/training_loop.py (7 additions & 1 deletion)

@@ -444,8 +444,10 @@ def run_training_epoch(self):
         dataloader_idx = 0
         val_loop_called = False
 
-        for batch_idx, (batch, is_last_batch) in train_dataloader:
+        batch_idx = None
+        is_last_batch = None
 
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
             self.trainer.is_last_batch = is_last_batch
 
@@ -516,6 +518,10 @@ def run_training_epoch(self):
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
 
+        if batch_idx is None:
+            # dataloader/iterator did not produce a batch
+            return
+
         # handle epoch_output on epoch end
         self.on_train_epoch_end(epoch_output)

tests/trainer/test_dataloaders.py (35 additions & 0 deletions)

@@ -779,6 +779,41 @@ def __len__(self):
     trainer.predict(model, dataloaders=dataloader)
 
 
+def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+    """ Test that the training loop skips execution if the iterator is empty from the start. """
+
+    class RandomDataset(IterableDataset):
+
+        def __init__(self, gen):
+            self.gen = gen
+
+        def __iter__(self):
+            return iter(self.gen())
+
+    class TestModel(BoringModel):
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(self.gen), batch_size=2)
+
+        def gen(self):
+            # produce data in epoch 0
+            # no data otherwise
+            if self.current_epoch == 0:
+                yield torch.rand(32)
+                yield torch.rand(32)
+                yield torch.rand(32)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        max_epochs=2,  # we expect the second epoch to be skipped
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 1
+
+
 @RunIf(min_gpus=2)
 def test_dataloader_reinit_for_subclass(tmpdir):

tests/trainer/test_supporters.py (26 additions & 1 deletion)

@@ -18,7 +18,7 @@
 import pytest
 import torch
 from torch.utils.data import DataLoader, TensorDataset
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler
 
@@ -29,6 +29,7 @@
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
+    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -78,6 +79,30 @@ def test_none_length_cycle_iterator():
     assert item == 0
 
 
+def test_prefetch_iterator():
+    """ Test the prefetch_iterator with PyTorch IterableDataset. """
+
+    class IterDataset(IterableDataset):
+
+        def __iter__(self):
+            yield 1
+            yield 2
+            yield 3
+
+    dataset = IterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == [(1, False), (2, False), (3, True)]
+
+    class EmptyIterDataset(IterableDataset):
+
+        def __iter__(self):
+            return iter([])
+
+    dataset = EmptyIterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == []
+
+
 @pytest.mark.parametrize(
     ["dataset_1", "dataset_2"],
     [
