
Commit a60037d
typing hell
1 parent 24a85ce

1 file changed: +6 −4 lines

pytorch_lightning/lite/lite.py (6 additions, 4 deletions)
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -188,7 +188,7 @@ def setup(

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]:
+    ) -> Union[DataLoader, List[DataLoader]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.
@@ -213,7 +213,7 @@ def setup_dataloaders(

     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> _LiteDataLoader:
+    ) -> DataLoader:
         """Setup a single dataloader for accelerated training.

         Args:
@@ -246,7 +246,9 @@ def _setup_dataloader(

         dataloader = self._strategy.process_dataloader(dataloader)
         device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
-        return _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = cast(DataLoader, lite_dataloader)
+        return lite_dataloader

     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
