Skip to content
Merged
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"

# Changes mypy default to ignore all errors
# Ignore mypy errors for these files
# TODO: the goal is for this to be empty
[[tool.mypy.overrides]]
# the list can be generated with:
Expand All @@ -63,8 +63,6 @@ module = [
"pytorch_lightning.core.mixins.hparams_mixin",
"pytorch_lightning.core.saving",
"pytorch_lightning.distributed.dist",
"pytorch_lightning.lite.lite",
"pytorch_lightning.lite.wrappers",
"pytorch_lightning.loggers.base",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.csv_logs",
Expand Down
16 changes: 14 additions & 2 deletions pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -201,7 +201,7 @@ def setup_dataloaders(
for dataloader in dataloaders
]
dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
return dataloaders
return dataloaders # type: ignore[return-value]

def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
Expand Down Expand Up @@ -284,6 +284,18 @@ def autocast(self) -> Generator[None, None, None]:
with self._precision_plugin.forward_context():
yield

# Typed overloads so a static checker narrows the return type of
# ``to_device``: an ``nn.Module`` argument comes back typed as ``nn.Module``,
# a ``Tensor`` as ``Tensor``, and anything else (e.g. a collection of
# tensors) falls through to ``Any``. These stubs carry no runtime logic;
# the real implementation follows below.
@overload
def to_device(self, obj: nn.Module) -> nn.Module:
    ...

@overload
def to_device(self, obj: Tensor) -> Tensor:
    ...

@overload
def to_device(self, obj: Any) -> Any:
    ...

def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
on that device.
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
iterator = iter(self._dataloader)
if self._device is None:
yield from iterator
return

for item in iterator:
yield move_data_to_device(item, self._device)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -241,7 +241,7 @@ def validation_step_end(self, output):
def test_step_end(self, output):
return output

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
"""Wraps the dataloader if necessary.

Args:
Expand Down