@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -55,7 +55,7 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")
 
-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None
 
     @property
     def _should_reload_train_dl(self) -> bool:
@@ -230,7 +230,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 category=PossibleUserWarning,
             )
 
-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed
@@ -292,14 +292,18 @@ def _prepare_dataloader(
 
         return dataloader
 
-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )
 
             # update docs too once this is resolved
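
The two added lines above pull `distributed_sampler_kwargs` into a local variable and assert it before unpacking: `Trainer.distributed_sampler_kwargs` is `Optional`, and `**`-unpacking an `Optional[dict]` fails type checking. A minimal sketch of that narrowing pattern, with hypothetical names standing in for the trainer property:

```python
from typing import Any, Dict, Optional


def sampler_kwargs() -> Optional[Dict[str, Any]]:
    """Hypothetical stand-in for `trainer.distributed_sampler_kwargs`."""
    return {"num_replicas": 2, "rank": 0}


def build(**kwargs: Any) -> None:
    print(kwargs)


kwargs = sampler_kwargs()
# `build(**sampler_kwargs())` is rejected by mypy because the value may be
# None; the assert narrows Optional[Dict[str, Any]] down to Dict[str, Any].
assert kwargs is not None
build(**kwargs)
```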
@@ -357,7 +361,7 @@ def _reset_eval_dataloader(
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)
 
         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]
 
         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -426,7 +430,7 @@ def _reset_eval_dataloader(
 
         return loader_num_batches, dataloaders
 
-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.
 
         Returns:
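
`TRAIN_DATALOADERS` is the alias from `pytorch_lightning.utilities.types` covering everything the dataloader hooks may legally return. Around this version of the codebase the aliases look roughly like the sketch below; check `utilities/types.py` for the exact definition:

```python
from typing import Dict, Sequence, Union

from torch.utils.data import DataLoader

# Approximate shape of the aliases in pytorch_lightning/utilities/types.py;
# the exact definition may differ between releases.
TRAIN_DATALOADERS = Union[
    DataLoader,
    Sequence[DataLoader],
    Sequence[Sequence[DataLoader]],
    Sequence[Dict[str, DataLoader]],
    Dict[str, DataLoader],
    Dict[str, Dict[str, DataLoader]],
    Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
```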
@@ -447,10 +451,12 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
         return dataloader
 
     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True
 
-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler
@@ -460,19 +466,23 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader):
 
         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )
 
         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )
 
         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)
 
         return dataloaders
 
     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
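
The `apply_to_collection` calls are what let `_resolve_overfit_batches` accept the broad `TRAIN_DATALOADERS`/`EVAL_DATALOADERS` union: the helper recurses through lists, tuples, and dicts and applies the given function only to values of the requested type. A small self-contained illustration (the `range` datasets are placeholders):

```python
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch.utils.data import DataLoader

loaders = {
    "train": DataLoader(range(10)),
    "extra": [DataLoader(range(4)), DataLoader(range(6))],
}


def describe(dl: DataLoader) -> str:
    # Called once for every DataLoader found anywhere in the collection;
    # values of other types pass through untouched.
    return f"{len(dl.dataset)} samples via {type(dl.sampler).__name__}"


print(apply_to_collection(loaders, DataLoader, describe))
# {'train': '10 samples via SequentialSampler',
#  'extra': ['4 samples via SequentialSampler', '6 samples via SequentialSampler']}
```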
@@ -506,18 +516,14 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
 
         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)
 
-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()
 
+        assert self.instance is not None
         return self.instance
 
     def is_defined(self) -> bool:
@@ -532,9 +538,7 @@ def is_module(self) -> bool:
 
         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))
 
 
 @dataclass
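
Both `_DataLoaderSource` hunks swap function-local imports for the module-level `pl` alias. The trick that makes this safe against the import cycle: `import pytorch_lightning as pl` at the top of the file binds only the module object, and `pl.LightningModule` is not looked up until the `isinstance` call actually runs, by which point the package has finished initializing. Schematically, with illustrative module names:

```python
# connector.py -- one half of a hypothetical import cycle with `mypackage`
import mypackage as mp  # binds the module object only; no attribute access yet


def is_module(obj: object) -> bool:
    # `mp.Thing` is resolved here at call time, long after both modules
    # have finished importing, so the cycle never bites.
    return isinstance(obj, mp.Thing)
```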
@@ -553,7 +557,7 @@ class _DataHookSelector:
 
     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
    )
 
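The `_valid_hooks` change fixes a classic typing slip: `Tuple[str]` means a tuple of exactly one string, so the three-element default never matched its own annotation, while `Tuple[str, ...]` is the homogeneous variable-length form. For example:

```python
from typing import Tuple

single: Tuple[str] = ("a",)  # exactly one str
variadic: Tuple[str, ...] = ("a", "b", "c")  # any number of strs

# mypy flags this: three items against a type that demands exactly one.
wrong: Tuple[str] = ("a", "b", "c")  # error: incompatible types in assignment
```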