Skip to content

Commit c1dccae

Browse files
committed
circular import
1 parent 0467092 commit c1dccae

File tree

1 file changed

+9
-5
lines changed
  • src/lightning/pytorch/utilities

1 file changed

+9
-5
lines changed

src/lightning/pytorch/utilities/data.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
 import inspect
 import os
 from dataclasses import fields
-from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union
+from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union, TYPE_CHECKING
 
 import torch
 from lightning_utilities.core.apply_func import is_dataclass_instance
@@ -30,10 +30,12 @@
     sized_len,
 )
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
-from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 
+if TYPE_CHECKING:
+    from lightning.pytorch.trainer.states import RunningStage
+
 BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
 
 warning_cache = WarningCache()
@@ -150,7 +152,7 @@ def has_len_all_ranks(
150152

151153

152154
def _update_dataloader(
153-
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
155+
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional["RunningStage"] = None
154156
) -> DataLoader:
155157
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
156158
return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
@@ -159,7 +161,7 @@ def _update_dataloader(
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
-    mode: Optional[RunningStage] = None,
+    mode: Optional["RunningStage"] = None,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -253,7 +255,7 @@ def _get_dataloader_init_args_and_kwargs(
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
-    mode: Optional[RunningStage] = None,
+    mode: Optional["RunningStage"] = None,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
     instantiation.
@@ -262,6 +264,8 @@ def _dataloader_init_kwargs_resolve_sampler(
     Lightning can keep track of its indices.
 
     """
+    from lightning.pytorch.trainer.states import RunningStage
+
     is_predicting = mode == RunningStage.PREDICTING
     batch_sampler = getattr(dataloader, "batch_sampler")
     batch_sampler_cls = type(batch_sampler)

0 commit comments

Comments
 (0)