1616from contextlib import contextmanager
1717from functools import partial
1818from pathlib import Path
19- from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Sequence , Tuple , Union
19+ from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Sequence , Tuple , Union , cast
2020
2121import torch
2222import torch .nn as nn
@@ -188,7 +188,7 @@ def setup(
188188
189189 def setup_dataloaders (
190190 self , * dataloaders : DataLoader , replace_sampler : bool = True , move_to_device : bool = True
191- ) -> Union [_LiteDataLoader , List [_LiteDataLoader ]]:
191+ ) -> Union [DataLoader , List [DataLoader ]]:
192192 """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
193193 dataloader, call this method individually for each one.
194194
@@ -213,7 +213,7 @@ def setup_dataloaders(
213213
214214 def _setup_dataloader (
215215 self , dataloader : DataLoader , replace_sampler : bool = True , move_to_device : bool = True
216- ) -> _LiteDataLoader :
216+ ) -> DataLoader :
217217 """Setup a single dataloader for accelerated training.
218218
219219 Args:
@@ -246,7 +246,9 @@ def _setup_dataloader(
246246
247247 dataloader = self ._strategy .process_dataloader (dataloader )
248248 device = self .device if move_to_device and not isinstance (self ._strategy , TPUSpawnPlugin ) else None
249- return _LiteDataLoader (dataloader = dataloader , device = device )
249+ lite_dataloader = _LiteDataLoader (dataloader = dataloader , device = device )
250+ lite_dataloader = cast (DataLoader , lite_dataloader )
251+ return lite_dataloader
250252
251253 def backward (self , tensor : Tensor , * args : Any , model : Optional [_LiteModule ] = None , ** kwargs : Any ) -> None :
252254 """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
0 commit comments