Commit d3f1493: Add typing to data fetching

Parent commit: 829db52
File tree: 4 files changed, +34 -42 lines changed

pyproject.toml (0 additions, 1 deletion)

@@ -95,7 +95,6 @@ module = [
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
     "pytorch_lightning.utilities.distributed",
-    "pytorch_lightning.utilities.fetching",
     "pytorch_lightning.utilities.memory",
     "pytorch_lightning.utilities.meta",
 ]
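
Context for this change: the surrounding `module = [...]` block is part of mypy's per-module overrides in pyproject.toml, which exempts the listed modules from type checking. Deleting `pytorch_lightning.utilities.fetching` from the list means the newly typed fetching.py is now checked by mypy like the rest of the codebase, which is what the remaining three files in this commit make possible.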

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py (2 additions, 1 deletion)

@@ -18,6 +18,7 @@
 from typing import Any, Dict, Iterator, Optional, Union

 from deprecate import void
+from torch.utils.data import DataLoader

 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.trainer.progress import BatchProgress
@@ -194,7 +195,7 @@ def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
                 "Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an issue"
                 " in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
             )
-        assert dataloader is not None
+        assert isinstance(dataloader, DataLoader)
         _reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
         self._dataloader_state_dict = {}
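
The change from `assert dataloader is not None` to `assert isinstance(dataloader, DataLoader)` is a typing fix, not just a stricter runtime check: mypy narrows a variable's static type through an `isinstance` assert, so the following call to `_reload_dataloader_state_dict`, which expects a concrete `DataLoader`, now type-checks. A minimal sketch of the narrowing pattern, assuming only a hypothetical `describe` helper:

from typing import Iterable, Optional, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


def describe(loader: DataLoader) -> str:
    # Hypothetical helper that requires a concrete DataLoader.
    return f"DataLoader(batch_size={loader.batch_size})"


def report(loader: Optional[Union[DataLoader, Iterable]]) -> str:
    # `assert loader is not None` would only strip None and leave the Union;
    # the isinstance assert narrows the static type all the way to DataLoader.
    assert isinstance(loader, DataLoader)
    return describe(loader)


print(report(DataLoader(TensorDataset(torch.zeros(4, 2)), batch_size=2)))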

pytorch_lightning/utilities/auto_restart.py (1 addition, 1 deletion)

@@ -487,7 +487,7 @@ def wrapper() -> Any:
 def patch_dataloader_iterator(
     dataloader: DataLoader,
     iterator: Iterator,
-    data_fetcher: "pl.utilities.fetching.DataFetcher",
+    data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
     num_batches_fetched: int = 0,
 ) -> None:
     """Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is

pytorch_lightning/utilities/fetching.py (31 additions, 39 deletions)

@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple

 import torch
 from torch.utils.data.dataloader import DataLoader
@@ -58,28 +58,28 @@ def fetching_function(self) -> Any:
     def prefetching(self) -> None:
         """Override with your own pre-fetching logic."""

+    def on_fetch_start(self) -> Any:
+        """Hook to override to handle the logic before fetching a batch."""
+
+    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
+        """Hook to extend which handles the logic after fetching a batch."""
+
+    def wait(self) -> None:
+        """Hook to override to indicate the `DataFetcher` to wait for an event."""
+
     def __init__(self, prefetch_batches: int = 0) -> None:
         if prefetch_batches < 0:
             raise MisconfigurationException("`prefetch_batches` should at least be 0.")
         self.prefetch_batches = prefetch_batches
-
-        self.dataloader: Optional[Union[DataLoader, CombinedLoader]] = None
+        self.dataloader: Optional[Iterable] = None
         self.dataloader_iter: Optional[Iterator] = None
-
-        self.batch_to_device: Optional[Callable]
-
-        self.batches: List
-        self.fetched: int
-        self.done: bool
-
+        self.batch_to_device: Optional[Callable] = None
         self.reset()

     def setup(self, dataloader: Iterable, batch_to_device: Optional[Callable] = None) -> None:
         self._add_capture_metadata_collate(dataloader)
-
         self.dataloader = dataloader
         self.batch_to_device = batch_to_device
-
         self._attach_data_fetcher()

     @staticmethod
@@ -92,8 +92,8 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:

         apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

-    def _apply_patch(self):
-        def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
+    def _apply_patch(self) -> None:
+        def _apply_patch_fn(loader: DataLoader, iterator: Iterator) -> None:
             if isinstance(loader, CycleIterator):
                 loader = loader.loader
                 # cycle_iterator = iterator
@@ -158,13 +158,13 @@ def loader_iters(self) -> List[Iterator]:

     @property
     def state(self) -> Any:
-        def collect_state(iterator: Iterator):
+        def collect_state(iterator: Iterator) -> Any:
             return iterator.state

         return apply_to_collection(self.loader_iters, Iterator, collect_state)

-    def _attach_data_fetcher(self):
-        def _attach_data_fetcher_fn(loader: DataLoader):
+    def _attach_data_fetcher(self) -> None:
+        def _attach_data_fetcher_fn(loader: DataLoader) -> None:
             if isinstance(loader, CycleIterator):
                 loader = loader.loader

@@ -173,7 +173,7 @@ def _attach_data_fetcher_fn(loader: DataLoader):

         apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

-    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
+    def __iter__(self) -> "AbstractDataFetcher":
         if self.dataloader is None:
             raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
         self.reset()
@@ -184,11 +184,11 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
             self.prefetching()
         return self

-    def __next__(self):
+    def __next__(self) -> Any:
         return self.fetching_function()

     def reset(self) -> None:
-        self.batches: List = []
+        self.batches: List[Any] = []
         self.fetched: int = 0
         self.done: bool = False

@@ -217,18 +217,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
         super().__init__(prefetch_batches=prefetch_batches)
         self.store_on_device = store_on_device

-    def on_fetch_start(self) -> None:
-        """Hook to override to handle the logic before fetching a batch."""
-
-    def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None:
+    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
         """Hook to extend which handles the logic after fetching a batch."""
         self.batches.append(batch)

-    def wait(self) -> None:
-        """Hook to override to indicate the `DataFetcher` to wait for an event."""
-
     def prefetching(self) -> None:
         iterator = self.dataloader_iter
+        assert iterator is not None
         for _ in range(self.prefetch_batches):
             try:
                 self._fetch_next_batch(iterator)
@@ -282,23 +277,21 @@ class InterBatchParallelDataFetcher(DataFetcher):
     batch 2: [HtoD] [forward][backward]
     """

-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.cuda_stream = torch.cuda.Stream()
         self.events: List[torch.cuda.Event] = []

-    def move_to_device(self, batch):
+    def move_to_device(self, batch: Any) -> Any:
         with torch.cuda.stream(self.cuda_stream):
             return super().move_to_device(batch)

     def on_fetch_start(self) -> "torch.cuda.Event":
         # create a cuda event used to record the async stream of data to device.
         return torch.cuda.Event()

-    def on_fetch_end(self, batch, event: torch.cuda.Event) -> None:
-        super().on_fetch_end(batch)
-
-        # record event and store the event
+    def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
+        self.batches.append(batch)
         event.record()
         self.events.append(event)

@@ -308,7 +301,7 @@ def wait(self) -> None:
             event.wait()


-class StepFuncDataLoaderIter:
+class StepFuncDataLoaderIter(Iterator):

     """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
     control."""
@@ -317,9 +310,6 @@ def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
         self.iterator = iterator
         self.data_fetcher = data_fetcher

-    def __iter__(self) -> "StepFuncDataLoaderIter":
-        return self
-
     def __next__(self) -> Any:
         try:
             data = next(self.iterator)
@@ -349,12 +339,14 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
             ...
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.store_on_device = False

     def prefetching(self) -> None:
-        self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
+        iterator = self.dataloader_iter
+        assert iterator is not None
+        self.iterator = iter(StepFuncDataLoaderIter(iterator, self))

     def fetching_function(self) -> Tuple[int, Tuple[Iterator, bool]]:
         if not self.done:
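
Two patterns in the fetching.py changes are worth spelling out. First, the hooks promoted to the abstract base define a small fetch protocol: `on_fetch_start` runs before a batch is pulled, and whatever it returns is passed back to `on_fetch_end` as `start_output`, which is how `InterBatchParallelDataFetcher` threads a `torch.cuda.Event` from the start of a fetch to its end. A minimal sketch of that choreography, using a hypothetical timing fetcher:

import time
from typing import Any, Iterator, List


class TimingFetcher:
    # Hypothetical fetcher illustrating the hook protocol:
    # start_output = on_fetch_start(); batch = fetch; on_fetch_end(batch, start_output)
    def __init__(self) -> None:
        self.batches: List[Any] = []

    def on_fetch_start(self) -> float:
        # Whatever this returns is threaded into on_fetch_end as start_output.
        return time.monotonic()

    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
        print(f"fetched batch in {time.monotonic() - start_output:.6f}s")
        self.batches.append(batch)

    def fetch_one(self, iterator: Iterator) -> None:
        start_output = self.on_fetch_start()
        batch = next(iterator)
        self.on_fetch_end(batch, start_output)


fetcher = TimingFetcher()
fetcher.fetch_one(iter([[1, 2], [3, 4]]))
print(fetcher.batches)  # [[1, 2]]

Second, `StepFuncDataLoaderIter` now subclasses `Iterator` (already imported from `collections.abc` at the top of the file), and that ABC supplies `__iter__` as a mixin method returning `self`, which is why the hand-written `__iter__` could be deleted; the new `__iter__(self) -> "AbstractDataFetcher"` return type on the base class follows the same return-self convention. A standalone sketch:

from collections.abc import Iterator


class Countdown(Iterator):
    # Only __next__ is required; __iter__ (returning self) comes
    # for free from the collections.abc.Iterator mixin.
    def __init__(self, start: int) -> None:
        self.current = start

    def __next__(self) -> int:
        if self.current <= 0:
            raise StopIteration
        self.current -= 1
        return self.current + 1


print(list(Countdown(3)))  # [3, 2, 1]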
