1414import inspect
1515import os
1616from dataclasses import fields
17- from typing import Any , Dict , Generator , Iterable , Mapping , Optional , Sized , Tuple , Union
17+ from typing import Any , Dict , Generator , Iterable , Mapping , Optional , Sized , Tuple , Union , TYPE_CHECKING
1818
1919import torch
2020from lightning_utilities .core .apply_func import is_dataclass_instance
3030 sized_len ,
3131)
3232from lightning .pytorch .overrides .distributed import _IndexBatchSamplerWrapper
33- from lightning .pytorch .trainer .states import RunningStage
3433from lightning .pytorch .utilities .exceptions import MisconfigurationException
3534from lightning .pytorch .utilities .rank_zero import rank_zero_warn , WarningCache
3635
36+ if TYPE_CHECKING :
37+ from lightning .pytorch .trainer .states import RunningStage
38+
# Recursive alias for anything a batch may contain: tensors, strings, or
# arbitrarily nested mappings/iterables thereof.
BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

# Module-level cache so repeated warnings from this file are emitted once.
warning_cache = WarningCache()
@@ -150,7 +152,7 @@ def has_len_all_ranks(
150152
151153
def _update_dataloader(
    dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional["RunningStage"] = None
) -> DataLoader:
    """Return a re-instantiated copy of ``dataloader`` with ``sampler`` injected.

    Args:
        dataloader: The :class:`~torch.utils.data.DataLoader` to rebuild.
        sampler: The sampler (or plain iterable) to attach to the new loader.
        mode: The current ``RunningStage``, forwarded to the kwargs resolution so
            stage-specific handling (e.g. prediction) can be applied. Kept as a
            string forward reference because ``RunningStage`` is only imported
            under ``TYPE_CHECKING`` to avoid a circular import.

    Returns:
        A new ``DataLoader`` built from the original's init args/kwargs with the
        sampler-related arguments replaced.
    """
    # Resolve the positional/keyword arguments the loader was constructed with,
    # swapping in the new sampler, then rebuild the (possibly subclassed) loader.
    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
    return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
@@ -159,7 +161,7 @@ def _update_dataloader(
159161def _get_dataloader_init_args_and_kwargs (
160162 dataloader : DataLoader ,
161163 sampler : Union [Sampler , Iterable ],
162- mode : Optional [RunningStage ] = None ,
164+ mode : Optional [" RunningStage" ] = None ,
163165) -> Tuple [Tuple [Any ], Dict [str , Any ]]:
164166 if not isinstance (dataloader , DataLoader ):
165167 raise ValueError (f"The dataloader { dataloader } needs to subclass `torch.utils.data.DataLoader`" )
@@ -253,7 +255,7 @@ def _get_dataloader_init_args_and_kwargs(
253255def _dataloader_init_kwargs_resolve_sampler (
254256 dataloader : DataLoader ,
255257 sampler : Union [Sampler , Iterable ],
256- mode : Optional [RunningStage ] = None ,
258+ mode : Optional [" RunningStage" ] = None ,
257259) -> Dict [str , Any ]:
258260 """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
259261 instantiation.
@@ -262,6 +264,8 @@ def _dataloader_init_kwargs_resolve_sampler(
262264 Lightning can keep track of its indices.
263265
264266 """
267+ from lightning .pytorch .trainer .states import RunningStage
268+
265269 is_predicting = mode == RunningStage .PREDICTING
266270 batch_sampler = getattr (dataloader , "batch_sampler" )
267271 batch_sampler_cls = type (batch_sampler )
0 commit comments