diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index bb07c763156aa..c26783c4d8bb9 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -188,7 +188,7 @@ def setup(
 
     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[Iterable, List[Iterable]]:
+    ) -> Union[DataLoader, List[DataLoader]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.
 
@@ -213,7 +213,7 @@ def setup_dataloaders(
 
     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Iterable:
+    ) -> DataLoader:
         """Setup a single dataloader for accelerated training.
 
         Args:
@@ -246,7 +246,9 @@ def _setup_dataloader(
 
         dataloader = self._strategy.process_dataloader(dataloader)
         device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
-        return _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = cast(DataLoader, lite_dataloader)
+        return lite_dataloader
 
     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index ff95e89d1d2cf..938eb72afe622 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -15,7 +15,7 @@
 import inspect
 from contextlib import contextmanager
 from itertools import chain
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union
 
 import torch
 from torch import nn as nn
@@ -157,29 +157,26 @@ def _replace_dataloader_init_method() -> Generator:
 
 
 class _LiteDataLoader:
-    def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None:
-        """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if
-        the device is specified.
+    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
+        """The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
+        device automatically if the device is specified.
 
         Args:
-            dataloader: The current dataloader to be used.
+            dataloader: The dataloader to wrap
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
         """
-        super().__init__()
-        self.__dict__.update(getattr(dataloader, "__dict__", {}))
+        self.__dict__.update(dataloader.__dict__)
         self._dataloader = dataloader
         self._device = device
 
-    def __len__(self) -> Union[int, float]:
-        if isinstance(self._dataloader, Sized):
-            return len(self._dataloader)
-        return float("inf")
-
     @property
     def device(self) -> Optional[torch.device]:
         return self._device
 
+    def __len__(self) -> int:
+        return len(self._dataloader)
+
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
         iterator = iter(self._dataloader)
         if self._device is None: