Merged
1 change: 0 additions & 1 deletion pyproject.toml
@@ -95,7 +95,6 @@ module = [
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.fetching",
"pytorch_lightning.utilities.memory",
"pytorch_lightning.utilities.meta",
]
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, Iterator, Optional

from deprecate import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import BatchProgress
@@ -195,7 +196,7 @@ def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
"Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an issue"
" in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
)
assert dataloader is not None
assert isinstance(dataloader, DataLoader)
_reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
self._dataloader_state_dict = {}

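A note on the assertion change above: swapping `assert dataloader is not None` for `assert isinstance(dataloader, DataLoader)` keeps the runtime guard while also narrowing the static type for mypy, which is why the `DataLoader` import is added at the top of the file. A minimal sketch of the pattern, using hypothetical names that are not part of this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset


def consume(maybe_loader: object) -> None:
    # The isinstance assert fails fast on unexpected inputs and lets mypy
    # treat `maybe_loader` as a `DataLoader` for the rest of the function.
    assert isinstance(maybe_loader, DataLoader)
    for batch in maybe_loader:
        print(batch)


consume(DataLoader(TensorDataset(torch.arange(4)), batch_size=2))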
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/auto_restart.py
@@ -487,7 +487,7 @@ def wrapper() -> Any:
def patch_dataloader_iterator(
dataloader: DataLoader,
iterator: Iterator,
data_fetcher: "pl.utilities.fetching.DataFetcher",
data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
num_batches_fetched: int = 0,
) -> None:
"""Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is
125 changes: 65 additions & 60 deletions pytorch_lightning/utilities/fetching.py
@@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.data.dataloader import DataLoader
@@ -58,29 +58,35 @@ def fetching_function(self) -> Any:
def prefetching(self) -> None:
"""Override with your own pre-fetching logic."""

def on_fetch_start(self) -> Any:
"""Hook to override to handle the logic before fetching a batch."""

def on_fetch_end(self, batch: Any, start_output: Any) -> None:
"""Hook to extend which handles the logic after fetching a batch."""

def wait(self) -> None:
"""Hook to override to indicate the `DataFetcher` to wait for an event."""

def __init__(self, prefetch_batches: int = 0) -> None:
if prefetch_batches < 0:
raise MisconfigurationException("`prefetch_batches` should at least be 0.")
self.prefetch_batches = prefetch_batches

self.dataloader: Optional[Union[DataLoader, CombinedLoader]] = None
self._dataloader: Optional[Iterable] = None
self.dataloader_iter: Optional[Iterator] = None
self.fetched: int = 0
self.done: bool = False

self.batch_to_device: Optional[Callable]

self.batches: List
self.fetched: int
self.done: bool

self.reset()

def setup(self, dataloader: Iterable, batch_to_device: Optional[Callable] = None) -> None:
def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
self._add_capture_metadata_collate(dataloader)
self._dataloader = dataloader

self.dataloader = dataloader
self.batch_to_device = batch_to_device

self._attach_data_fetcher()
@property
def dataloader(self) -> Iterable:
if self._dataloader is None:
raise MisconfigurationException(
f"`{self.__class__.__name__}` should have been `setup` with a dataloader iterable."
)
return self._dataloader

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
@@ -92,8 +98,8 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:

apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

def _apply_patch(self):
def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
def _apply_patch(self) -> None:
def _apply_patch_fn(loader: DataLoader, iterator: Iterator) -> None:
if isinstance(loader, CycleIterator):
loader = loader.loader
# cycle_iterator = iterator
@@ -130,10 +136,6 @@ def _store_dataloader_iter_state(

@property
def loaders(self) -> List[DataLoader]:
if self.dataloader is None:
raise MisconfigurationException(
"The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
)
if isinstance(self.dataloader, CombinedLoader):
loaders = self.dataloader.loaders
else:
@@ -142,11 +144,6 @@ def loader_iters(self) -> List[Iterator]:

@property
def loader_iters(self) -> List[Iterator]:
if self.dataloader is None:
raise MisconfigurationException(
"The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
)

if self.dataloader_iter is None:
raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")

@@ -157,14 +154,14 @@ def loader_iters(self) -> List[Iterator]:
return loader_iters

@property
def state(self) -> Any:
def collect_state(iterator: Iterator):
def state(self) -> List[MergedIteratorState]:
def collect_state(iterator: Iterator) -> MergedIteratorState:
return iterator.state

return apply_to_collection(self.loader_iters, Iterator, collect_state)

def _attach_data_fetcher(self):
def _attach_data_fetcher_fn(loader: DataLoader):
def _attach_data_fetcher(self) -> None:
def _attach_data_fetcher_fn(loader: DataLoader) -> None:
if isinstance(loader, CycleIterator):
loader = loader.loader

@@ -173,9 +170,7 @@ def _attach_data_fetcher_fn(loader: DataLoader):

apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
if self.dataloader is None:
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
def __iter__(self) -> "AbstractDataFetcher":
self.reset()
self._attach_data_fetcher()
_patch_dataloader_get_iterators()
@@ -184,13 +179,12 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
self.prefetching()
return self

def __next__(self):
def __next__(self) -> Any:
return self.fetching_function()

def reset(self) -> None:
self.batches: List = []
self.fetched: int = 0
self.done: bool = False
self.fetched = 0
self.done = False

def teardown(self) -> None:
self.reset()
@@ -202,6 +196,10 @@ def teardown(self) -> None:
_teardown_dataloader_get_iterators()


def _no_op_batch_to_device(batch: Any) -> Any:
return batch
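# Note (not part of this diff): the new `_no_op_batch_to_device` helper lets
# `DataFetcher` default `batch_to_device` to a real callable instead of `None`,
# so `move_to_device` further down no longer needs an `is not None` check.
# A minimal sketch of the same pattern, with hypothetical names:
from typing import Any, Callable, Optional


def _identity(batch: Any) -> Any:
    return batch


class _SketchFetcher:
    def __init__(self, batch_to_device: Optional[Callable[[Any], Any]] = None) -> None:
        # Defaulting to a no-op callable keeps every call site unconditional.
        self.batch_to_device = batch_to_device or _identity

    def move_to_device(self, batch: Any) -> Any:
        return self.batch_to_device(batch)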


class DataFetcher(AbstractDataFetcher):

"""This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
@@ -214,21 +212,27 @@ class DataFetcher(AbstractDataFetcher):
"""

def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
if prefetch_batches < 1:
raise MisconfigurationException("`prefetch_batches` should at least be 1.")
super().__init__(prefetch_batches=prefetch_batches)
self.store_on_device = store_on_device
self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
self.batches: List[Any] = []

def on_fetch_start(self) -> None:
"""Hook to override to handle the logic before fetching a batch."""
def setup( # type: ignore[override]
self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
) -> None:
super().setup(dataloader)
if batch_to_device is not None:
self.batch_to_device = batch_to_device

def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None:
def on_fetch_end(self, batch: Any, start_output: Any) -> None:
"""Hook to extend which handles the logic after fetching a batch."""
self.batches.append(batch)

def wait(self) -> None:
"""Hook to override to indicate the `DataFetcher` to wait for an event."""

def prefetching(self) -> None:
iterator = self.dataloader_iter
assert iterator is not None
for _ in range(self.prefetch_batches):
try:
self._fetch_next_batch(iterator)
@@ -257,10 +261,14 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
self.on_fetch_end(batch, start_output)

def move_to_device(self, batch: Any) -> Any:
if self.store_on_device and self.batch_to_device is not None:
if self.store_on_device:
batch = self.batch_to_device(batch)
return batch

def reset(self) -> None:
super().reset()
self.batches = []


class InterBatchParallelDataFetcher(DataFetcher):

@@ -282,23 +290,21 @@ class InterBatchParallelDataFetcher(DataFetcher):
batch 2: [HtoD] [forward][backward]
"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.cuda_stream = torch.cuda.Stream()
self.events: List[torch.cuda.Event] = []

def move_to_device(self, batch):
def move_to_device(self, batch: Any) -> Any:
with torch.cuda.stream(self.cuda_stream):
return super().move_to_device(batch)

def on_fetch_start(self) -> "torch.cuda.Event":
# create a cuda event used to record the async stream of data to device.
return torch.cuda.Event()

def on_fetch_end(self, batch, event: torch.cuda.Event) -> None:
super().on_fetch_end(batch)

# record event and store the event
def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
self.batches.append(batch)
event.record()
self.events.append(event)
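# Note (not part of this diff): the class above overlaps host-to-device copies
# with compute by issuing the copy on a side CUDA stream, recording a
# `torch.cuda.Event` per fetched batch, and waiting on that event before the
# batch is consumed. A standalone sketch of that pattern (requires a CUDA
# device; the names below are hypothetical and not taken from this class):
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    batch = torch.randn(32, 128, pin_memory=True)

    with torch.cuda.stream(copy_stream):
        # The async copy runs on the side stream while the default stream computes.
        batch_gpu = batch.to("cuda", non_blocking=True)
        copied = torch.cuda.Event()
        copied.record()

    # Make the default (current) stream wait for the copy before using the batch.
    copied.wait()
    out = batch_gpu.sum()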

@@ -308,26 +314,23 @@ def wait(self) -> None:
event.wait()


class StepFuncDataLoaderIter:
class StepFuncDataLoaderIter(Iterator):

"""This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
control."""

def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
def __init__(self, iterator: Iterator, data_fetcher: AbstractDataFetcher) -> None:
self.iterator = iterator
self.data_fetcher = data_fetcher

def __iter__(self) -> "StepFuncDataLoaderIter":
return self

def __next__(self) -> Any:
try:
data = next(self.iterator)
self.data_fetcher.fetched += 1
return data
except StopIteration:
except StopIteration as e:
self.data_fetcher.done = True
raise StopIteration
raise e


class DataLoaderIterDataFetcher(AbstractDataFetcher):
@@ -349,12 +352,14 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
...
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.store_on_device = False

def prefetching(self) -> None:
self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
iterator = self.dataloader_iter
assert iterator is not None
self.iterator = iter(StepFuncDataLoaderIter(iterator, self))

def fetching_function(self) -> Tuple[int, Tuple[Iterator, bool]]:
if not self.done:
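Taken together, the `fetching.py` changes narrow `AbstractDataFetcher.setup` to just the dataloader, move `batch_to_device`, `batches`, and the fetch hooks down to `DataFetcher`, and turn `dataloader` into a property that raises a `MisconfigurationException` when `setup` was never called. A rough usage sketch against the new surface, assuming the rest of the class behaves as before (the exact shape of the items yielded per step is not shown in this diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.fetching import DataFetcher

loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)
fetcher = DataFetcher(prefetch_batches=1)

# `batch_to_device` is optional; when omitted it stays a no-op callable.
fetcher.setup(loader, batch_to_device=lambda batch: batch)

iterator = iter(fetcher)  # returns the fetcher itself and runs `prefetching()`
first = next(iterator)    # drives `fetching_function()` for one step
fetcher.teardown()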