Skip to content

Commit db26e08

Browse files
authored
Close profiler when StopIteration is raised (#14945)
1 parent d7af8ce commit db26e08

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
300300
- `SaveConfigCallback` instances should only save the config once to allow having the `overwrite=False` safeguard when using `LightningCLI(..., run=False)` ([#14927](https://github.com/Lightning-AI/lightning/pull/14927))
301301

302302

303+
304+
- Fixed an issue with terminating the trainer profiler when a `StopIteration` exception is raised while using an `IterableDataset` ([#14940](https://github.com/Lightning-AI/lightning/pull/14945))
305+
306+
303307
## [1.7.7] - 2022-09-22
304308

305309
### Fixed

src/pytorch_lightning/utilities/fetching.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,11 @@ def fetching_function(self) -> Any:
276276

277277
def _fetch_next_batch(self, iterator: Iterator) -> None:
278278
start_output = self.on_fetch_start()
279-
batch = next(iterator)
279+
try:
280+
batch = next(iterator)
281+
except StopIteration as e:
282+
self._stop_profiler()
283+
raise e
280284
self.fetched += 1
281285
if not self.prefetch_batches and self._has_len:
282286
# when we don't prefetch but the dataloader is sized, we use the length for `done`

tests/tests_pytorch/utilities/test_fetching.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,31 @@ def generate():
7878
assert fetcher.fetched == 3
7979

8080

81+
@pytest.mark.parametrize("use_combined_loader", [False, True])
82+
def test_profiler_closing(use_combined_loader):
83+
"""Tests if the profiler terminates upon raising a StopIteration on an iterable dataset."""
84+
85+
class TestDataset(IterableDataset):
86+
def __init__(self):
87+
self.list = list(range(1))
88+
89+
def __iter__(self):
90+
return iter(self.list)
91+
92+
fetcher = DataFetcher()
93+
if use_combined_loader:
94+
loader = CombinedLoader([DataLoader(TestDataset()), DataLoader(TestDataset())])
95+
else:
96+
loader = DataLoader(TestDataset())
97+
fetcher.setup(loader)
98+
profiler = SimpleProfiler()
99+
fetcher._start_profiler = lambda: profiler.start("test")
100+
fetcher._stop_profiler = lambda: profiler.stop("test")
101+
iter(fetcher) # on epoch 0 start
102+
next(fetcher) # raises StopIteration exception
103+
assert not bool(profiler.current_actions)
104+
105+
81106
class EmptyIterDataset(IterableDataset):
82107
def __iter__(self):
83108
return iter([])

0 commit comments

Comments
 (0)