diff --git a/pyproject.toml b/pyproject.toml
index bd151cb468a5e..492a88a3c1bd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,7 +95,6 @@ module = [
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
     "pytorch_lightning.utilities.distributed",
-    "pytorch_lightning.utilities.fetching",
     "pytorch_lightning.utilities.memory",
     "pytorch_lightning.utilities.meta",
 ]
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 4c2ddcf85020a..da36544d2ba6a 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -18,6 +18,7 @@
 from typing import Any, Dict, Iterator, Optional
 
 from deprecate import void
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.trainer.progress import BatchProgress
@@ -195,7 +196,7 @@ def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> No
                 "Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an issue"
                 " in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
             )
-        assert dataloader is not None
+        assert isinstance(dataloader, DataLoader)
         _reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
         self._dataloader_state_dict = {}
 
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index ec630f795d8cc..7e62ebc5d3d60 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -487,7 +487,7 @@ def wrapper() -> Any:
 def patch_dataloader_iterator(
     dataloader: DataLoader,
     iterator: Iterator,
-    data_fetcher: "pl.utilities.fetching.DataFetcher",
+    data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
     num_batches_fetched: int = 0,
 ) -> None:
     """Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index ccbbce79829de..5b8012468dfef 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple
 
 import torch
 from torch.utils.data.dataloader import DataLoader
@@ -58,29 +58,35 @@ def fetching_function(self) -> Any:
     def prefetching(self) -> None:
         """Override with your own pre-fetching logic."""
 
+    def on_fetch_start(self) -> Any:
+        """Hook to override to handle the logic before fetching a batch."""
+
+    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
+        """Hook to extend which handles the logic after fetching a batch."""
+
+    def wait(self) -> None:
+        """Hook to override to indicate the `DataFetcher` to wait for an event."""
+
     def __init__(self, prefetch_batches: int = 0) -> None:
         if prefetch_batches < 0:
             raise MisconfigurationException("`prefetch_batches` should at least be 0.")
         self.prefetch_batches = prefetch_batches
-
-        self.dataloader: Optional[Union[DataLoader, CombinedLoader]] = None
+        self._dataloader: Optional[Iterable] = None
         self.dataloader_iter: Optional[Iterator] = None
+        self.fetched: int = 0
+        self.done: bool = False
-
-        self.batch_to_device: Optional[Callable]
-
-        self.batches: List
-        self.fetched: int
-        self.done: bool
-
-        self.reset()
-
-    def setup(self, dataloader: Iterable, batch_to_device: Optional[Callable] = None) -> None:
+
+    def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
         self._add_capture_metadata_collate(dataloader)
+        self._dataloader = dataloader
 
-        self.dataloader = dataloader
-        self.batch_to_device = batch_to_device
-
-        self._attach_data_fetcher()
+    @property
+    def dataloader(self) -> Iterable:
+        if self._dataloader is None:
+            raise MisconfigurationException(
+                f"`{self.__class__.__name__}` should have been `setup` with a dataloader iterable."
+            )
+        return self._dataloader
 
     @staticmethod
     def _add_capture_metadata_collate(dataloader: Iterable) -> None:
@@ -92,8 +98,8 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:
 
         apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)
 
-    def _apply_patch(self):
-        def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
+    def _apply_patch(self) -> None:
+        def _apply_patch_fn(loader: DataLoader, iterator: Iterator) -> None:
             if isinstance(loader, CycleIterator):
                 loader = loader.loader
                 # cycle_iterator = iterator
@@ -130,10 +136,6 @@ def _store_dataloader_iter_state(
 
     @property
     def loaders(self) -> List[DataLoader]:
-        if self.dataloader is None:
-            raise MisconfigurationException(
-                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
-            )
         if isinstance(self.dataloader, CombinedLoader):
             loaders = self.dataloader.loaders
         else:
@@ -142,11 +144,6 @@ def loaders(self) -> List[DataLoader]:
 
     @property
     def loader_iters(self) -> List[Iterator]:
-        if self.dataloader is None:
-            raise MisconfigurationException(
-                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
-            )
-
         if self.dataloader_iter is None:
             raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
 
@@ -157,14 +154,14 @@ def loader_iters(self) -> List[Iterator]:
         return loader_iters
 
     @property
-    def state(self) -> Any:
-        def collect_state(iterator: Iterator):
+    def state(self) -> List[MergedIteratorState]:
+        def collect_state(iterator: Iterator) -> MergedIteratorState:
             return iterator.state
 
         return apply_to_collection(self.loader_iters, Iterator, collect_state)
 
-    def _attach_data_fetcher(self):
-        def _attach_data_fetcher_fn(loader: DataLoader):
+    def _attach_data_fetcher(self) -> None:
+        def _attach_data_fetcher_fn(loader: DataLoader) -> None:
             if isinstance(loader, CycleIterator):
                 loader = loader.loader
 
@@ -173,9 +170,7 @@ def _attach_data_fetcher_fn(loader: DataLoader):
 
         apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)
 
-    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
-        if self.dataloader is None:
-            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
+    def __iter__(self) -> "AbstractDataFetcher":
         self.reset()
         self._attach_data_fetcher()
         _patch_dataloader_get_iterators()
@@ -184,13 +179,12 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
         self.prefetching()
         return self
 
-    def __next__(self):
+    def __next__(self) -> Any:
         return self.fetching_function()
 
     def reset(self) -> None:
-        self.batches: List = []
-        self.fetched: int = 0
-        self.done: bool = False
+        self.fetched = 0
+        self.done = False
 
     def teardown(self) -> None:
         self.reset()
@@ -202,6 +196,10 @@ def teardown(self) -> None:
         _teardown_dataloader_get_iterators()
 
 
+def _no_op_batch_to_device(batch: Any) -> Any:
+    return batch
+
+
 class DataFetcher(AbstractDataFetcher):
 
     """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
@@ -214,21 +212,27 @@ class DataFetcher(AbstractDataFetcher):
     """
 
     def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
+        if prefetch_batches < 1:
+            raise MisconfigurationException("`prefetch_batches` should at least be 1.")
         super().__init__(prefetch_batches=prefetch_batches)
         self.store_on_device = store_on_device
+        self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
+        self.batches: List[Any] = []
 
-    def on_fetch_start(self) -> None:
-        """Hook to override to handle the logic before fetching a batch."""
+    def setup(  # type: ignore[override]
+        self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
+    ) -> None:
+        super().setup(dataloader)
+        if batch_to_device is not None:
+            self.batch_to_device = batch_to_device
 
-    def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None:
+    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
         """Hook to extend which handles the logic after fetching a batch."""
         self.batches.append(batch)
 
-    def wait(self) -> None:
-        """Hook to override to indicate the `DataFetcher` to wait for an event."""
-
     def prefetching(self) -> None:
         iterator = self.dataloader_iter
+        assert iterator is not None
         for _ in range(self.prefetch_batches):
             try:
                 self._fetch_next_batch(iterator)
@@ -257,10 +261,14 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
         self.on_fetch_end(batch, start_output)
 
     def move_to_device(self, batch: Any) -> Any:
-        if self.store_on_device and self.batch_to_device is not None:
+        if self.store_on_device:
             batch = self.batch_to_device(batch)
         return batch
 
+    def reset(self) -> None:
+        super().reset()
+        self.batches = []
+
 
 class InterBatchParallelDataFetcher(DataFetcher):
 
@@ -282,12 +290,12 @@ class InterBatchParallelDataFetcher(DataFetcher):
     batch 2:                      [HtoD]                [forward][backward]
     """
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.cuda_stream = torch.cuda.Stream()
         self.events: List[torch.cuda.Event] = []
 
-    def move_to_device(self, batch):
+    def move_to_device(self, batch: Any) -> Any:
         with torch.cuda.stream(self.cuda_stream):
             return super().move_to_device(batch)
 
@@ -295,10 +303,8 @@ def on_fetch_start(self) -> "torch.cuda.Event":
         # create a cuda event used to record the async stream of data to device.
         return torch.cuda.Event()
 
-    def on_fetch_end(self, batch, event: torch.cuda.Event) -> None:
-        super().on_fetch_end(batch)
-
-        # record event and store the event
+    def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
+        self.batches.append(batch)
         event.record()
         self.events.append(event)
 
@@ -308,26 +314,23 @@ def wait(self) -> None:
             event.wait()
 
 
-class StepFuncDataLoaderIter:
+class StepFuncDataLoaderIter(Iterator):
 
     """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
     control."""
 
-    def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
+    def __init__(self, iterator: Iterator, data_fetcher: AbstractDataFetcher) -> None:
         self.iterator = iterator
         self.data_fetcher = data_fetcher
 
-    def __iter__(self) -> "StepFuncDataLoaderIter":
-        return self
-
     def __next__(self) -> Any:
         try:
             data = next(self.iterator)
             self.data_fetcher.fetched += 1
             return data
-        except StopIteration:
+        except StopIteration as e:
             self.data_fetcher.done = True
-            raise StopIteration
+            raise e
 
 
 class DataLoaderIterDataFetcher(AbstractDataFetcher):
@@ -349,12 +352,14 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
                 ...
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.store_on_device = False
 
     def prefetching(self) -> None:
-        self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
+        iterator = self.dataloader_iter
+        assert iterator is not None
+        self.iterator = iter(StepFuncDataLoaderIter(iterator, self))
 
     def fetching_function(self) -> Tuple[int, Tuple[Iterator, bool]]:
         if not self.done:
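
As a usage illustration (not part of the patch): a minimal sketch of the reworked `DataFetcher.setup` API, based only on what the diff above shows, namely the `dataloader` property guard, the optional `batch_to_device` keyword, and the `_no_op_batch_to_device` default. The toy dataset, batch size, and lambda below are assumptions made for demonstration.

```python
# Illustrative sketch only; exercises the API surface changed above.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher

fetcher = DataFetcher(prefetch_batches=1)

# Before `setup`, the new `dataloader` property raises instead of returning `None`.
try:
    _ = fetcher.dataloader
except MisconfigurationException as err:
    print(err)  # "`DataFetcher` should have been `setup` with a dataloader iterable."

loader = DataLoader(TensorDataset(torch.arange(8, dtype=torch.float32)), batch_size=2)

# `batch_to_device` is now an optional keyword on `DataFetcher.setup`; when omitted,
# the `_no_op_batch_to_device` default leaves batches untouched.
fetcher.setup(loader, batch_to_device=lambda batch: batch)  # e.g. move the batch to a device here
assert fetcher.dataloader is loader
```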