from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import _load_from_checkpoint
- from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
- from pytorch_lightning.utilities.types import _PATH
+ from pytorch_lightning.utilities.argparse import (
+     add_argparse_args,
+     from_argparse_args,
+     get_init_arguments_and_types,
+     parse_argparser,
+ )
+ from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS


class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -55,7 +60,7 @@ def teardown(self):
            # called on every process in DDP
    """

-     name: str = ...
+     name: Optional[str] = None
    CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
    CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
    CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"
@@ -66,7 +71,7 @@ def __init__(self) -> None:
        self.trainer: Optional["pl.Trainer"] = None

    @classmethod
-     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> _ADD_ARGPARSE_RETURN:
        """Extends existing argparse by default `LightningDataModule` attributes.

        Example::
@@ -77,7 +82,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentP
        return add_argparse_args(cls, parent_parser, **kwargs)

    @classmethod
-     def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
+     def from_argparse_args(
+         cls, args: Union[Namespace, ArgumentParser], **kwargs: Any
+     ) -> Union["pl.LightningDataModule", "pl.Trainer"]:
        """Create an instance from CLI arguments.

        Args:
@@ -92,6 +99,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
        """
        return from_argparse_args(cls, args, **kwargs)

+     @classmethod
+     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+         return parse_argparser(cls, arg_parser)
+
    @classmethod
    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
        r"""Scans the DataModule signature and returns argument names, types and default values.
@@ -102,6 +113,15 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
        """
        return get_init_arguments_and_types(cls)

+     @classmethod
+     def get_deprecated_arg_names(cls) -> List:
+         """Returns a list with deprecated DataModule arguments."""
+         depr_arg_names: List[str] = []
+         for name, val in cls.__dict__.items():
+             if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
+                 depr_arg_names.extend(val)
+         return depr_arg_names
+
    @classmethod
    def from_datasets(
        cls,
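The added ``get_deprecated_arg_names`` helper scans the class namespace for tuple or list attributes whose names start with ``DEPRECATED`` and flattens them into one list. A hypothetical subclass illustrating the naming convention the scan relies on:

    class MyDataModule(LightningDataModule):  # hypothetical example
        # any DEPRECATED* tuple/list class attribute is collected by the scan
        DEPRECATED_ARGS = ("old_arg", "legacy_flag")

    MyDataModule.get_deprecated_arg_names()  # -> ["old_arg", "legacy_flag"]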
@@ -112,7 +132,7 @@ def from_datasets(
        batch_size: int = 1,
        num_workers: int = 0,
        **datamodule_kwargs: Any,
-     ):
+     ) -> "LightningDataModule":
        r"""
        Create an instance from torch.utils.data.Dataset.
@@ -133,24 +153,32 @@ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
            shuffle &= not isinstance(ds, IterableDataset)
            return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

-         def train_dataloader():
+         def train_dataloader() -> TRAIN_DATALOADERS:
+             assert train_dataset
+
            if isinstance(train_dataset, Mapping):
                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
            if isinstance(train_dataset, Sequence):
                return [dataloader(ds, shuffle=True) for ds in train_dataset]
            return dataloader(train_dataset, shuffle=True)

-         def val_dataloader():
+         def val_dataloader() -> EVAL_DATALOADERS:
+             assert val_dataset
+
            if isinstance(val_dataset, Sequence):
                return [dataloader(ds) for ds in val_dataset]
            return dataloader(val_dataset)

-         def test_dataloader():
+         def test_dataloader() -> EVAL_DATALOADERS:
+             assert test_dataset
+
            if isinstance(test_dataset, Sequence):
                return [dataloader(ds) for ds in test_dataset]
            return dataloader(test_dataset)

-         def predict_dataloader():
+         def predict_dataloader() -> EVAL_DATALOADERS:
+             assert predict_dataset
+
            if isinstance(predict_dataset, Sequence):
                return [dataloader(ds) for ds in predict_dataset]
            return dataloader(predict_dataset)
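A note on the ``assert`` statements added to each closure: the captured datasets are Optional parameters of ``from_datasets``, so the assert guards against a missing dataset at call time and also narrows the Optional type for mypy inside the closure body. A tiny sketch of that narrowing pattern (names hypothetical; the diff uses a truthiness assert, which narrows the same way):

    from typing import Optional

    def make_loader(ds: Optional[str]):
        def loader() -> str:
            assert ds is not None  # mypy narrows Optional[str] to str from here on
            return ds.upper()
        return loader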
@@ -161,19 +189,19 @@ def predict_dataloader():
        if accepts_kwargs:
            special_kwargs = candidate_kwargs
        else:
-             accepted_params = set(accepted_params)
-             accepted_params.discard("self")
-             special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params}
+             accepted_param_names = set(accepted_params)
+             accepted_param_names.discard("self")
+             special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

        datamodule = cls(**datamodule_kwargs, **special_kwargs)
        if train_dataset is not None:
-             datamodule.train_dataloader = train_dataloader
+             datamodule.train_dataloader = train_dataloader  # type: ignore[assignment]
        if val_dataset is not None:
-             datamodule.val_dataloader = val_dataloader
+             datamodule.val_dataloader = val_dataloader  # type: ignore[assignment]
        if test_dataset is not None:
-             datamodule.test_dataloader = test_dataloader
+             datamodule.test_dataloader = test_dataloader  # type: ignore[assignment]
        if predict_dataset is not None:
-             datamodule.predict_dataloader = predict_dataloader
+             datamodule.predict_dataloader = predict_dataloader  # type: ignore[assignment]
        return datamodule

    def state_dict(self) -> Dict[str, Any]:
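Taken together, ``from_datasets`` builds one dataloader closure per split, attaches only the closures whose dataset was actually provided, and forwards any remaining keyword arguments that the subclass ``__init__`` accepts. A hedged usage sketch with placeholder tensors:

    import torch
    from torch.utils.data import TensorDataset

    train_ds = TensorDataset(torch.randn(100, 32))  # placeholder data
    val_ds = TensorDataset(torch.randn(20, 32))

    dm = LightningDataModule.from_datasets(
        train_dataset=train_ds,
        val_dataset=val_ds,
        batch_size=8,
        num_workers=2,
    )
    batch = next(iter(dm.train_dataloader()))  # shuffled loader over train_ds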
@@ -197,8 +225,8 @@ def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        hparams_file: Optional[_PATH] = None,
-         **kwargs,
-     ):
+         **kwargs: Any,
+     ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
        r"""
        Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
        it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
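The widened return annotation reflects that the underlying ``_load_from_checkpoint`` helper is shared with ``LightningModule``; at a concrete call site the result is still the datamodule subclass. A minimal sketch, assuming a checkpoint written by a run that used the datamodule (``MyDataModule`` and the path are placeholders):

    dm = MyDataModule.load_from_checkpoint("path/to/checkpoint.ckpt")
    # the arguments stored under "datamodule_hyper_parameters" re-instantiate it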