 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Iterable, Optional, Union
+from typing import Iterable, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_deprecation
@@ -47,6 +48,11 @@ def __init__(
         self.test_data_fetcher = test_data_fetcher
         self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None
 
+        self._train_dataloader_source = _DataLoaderSource(None, "")
+        self._val_dataloader_source = _DataLoaderSource(None, "")
+        self._test_dataloader_source = _DataLoaderSource(None, "")
+        self._predict_dataloader_source = _DataLoaderSource(None, "")
+
     @property
     def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
         if self.trainer.sanity_checking:
@@ -190,27 +196,23 @@ def attach_dataloaders(
         test_dataloaders: Optional[EVAL_DATALOADERS] = None,
         predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
     ) -> None:
-        # when dataloader is passed via fit, patch the train_dataloader
-        # functions to overwrite with these implementations
-        if train_dataloaders is not None:
-            self.trainer.train_dataloader = None
-            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
-            train_dataloader.patch(model)
-
-        if val_dataloaders is not None:
-            self.trainer.val_dataloaders = None
-            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
-            val_dataloader.patch(model)
-
-        if test_dataloaders is not None:
-            self.trainer.test_dataloaders = None
-            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
-            test_dataloader.patch(model)
-
-        if predict_dataloaders is not None:
-            self.trainer.predict_dataloaders = None
-            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
-            predict_dataloader.patch(model)
+        self.trainer.train_dataloader = None
+        self.trainer.val_dataloaders = None
+        self.trainer.test_dataloaders = None
+        self.trainer.predict_dataloaders = None
+
+        self._train_dataloader_source = _DataLoaderSource(
+            train_dataloaders if train_dataloaders is not None else model, "train_dataloader"
+        )
+        self._val_dataloader_source = _DataLoaderSource(
+            val_dataloaders if val_dataloaders is not None else model, "val_dataloader"
+        )
+        self._test_dataloader_source = _DataLoaderSource(
+            test_dataloaders if test_dataloaders is not None else model, "test_dataloader"
+        )
+        self._predict_dataloader_source = _DataLoaderSource(
+            predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader"
+        )
 
     def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -219,11 +221,10 @@ def attach_datamodule(
         if datamodule is None:
             return
 
-        # Override loader hooks
-        dl_methods = ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader")
-        for method in dl_methods:
-            if is_overridden(method, datamodule):
-                setattr(model, method, getattr(datamodule, method))
+        self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
+        self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
+        self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
+        self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")
 
         # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
         batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
@@ -238,13 +239,6 @@ def attach_datamodule(
         if hasattr(datamodule, "data_pipeline"):
             model.data_pipeline = datamodule.data_pipeline
 
-    @staticmethod
-    def detach_data(model: "pl.LightningModule") -> None:
-        for stage in ("train", "val", "test", "predict"):
-            loader = getattr(model, f"{stage}_dataloader", None)
-            if isinstance(loader, _PatchDataLoader):
-                loader.unpatch(model)
-
     def teardown(self) -> None:
         if self.train_data_fetcher:
             self.train_data_fetcher.teardown()
@@ -260,32 +254,56 @@ def teardown(self) -> None:
         self.sanity_check_data_fetcher = None
 
 
-class _PatchDataLoader:
-    r"""
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
+@dataclass
+class _DataLoaderSource:
+    """Stores the information where the dataloaders come from.
+
+    The source can be
 
-    Args:
-        dataloader: Dataloader object to return when called.
+    1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`,
+    2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof.
+
+    Arguments:
+        instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s).
+        name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
+            that returns the desired dataloader(s).
     """
 
-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
-        self.dataloader = dataloader
+    instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]]
+    name: str
+
+    def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
+        """Returns the dataloader from the source.
+
+        If the source is a module, the method with the corresponding :attr:`name` gets called.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
+
+        if not self.name:
+            return self.instance
+
+        if isinstance(self.instance, LightningModule):
+            return self.instance.trainer.call_hook(self.name, pl_module=self.instance)
+
+        if isinstance(self.instance, LightningDataModule):
+            method = getattr(self.instance, self.name)
+            return method()
+
+        return self.instance
+
+    def is_defined(self) -> bool:
+        """Returns whether the source dataloader can be retrieved or not.
 
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        self.patch_loader_code = str(self.__call__.__code__)
-        self._old_loader: Optional[Callable] = None
-        self.stage = stage
+        If the source is a module it checks that the method with given :attr:`name` is overridden.
+        """
+        return not self.is_module() or is_overridden(self.name, self.instance)
 
-    def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
-        return self.dataloader
+    def is_module(self) -> bool:
303+ """Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.
 
-    def patch(self, model: "pl.LightningModule") -> None:
-        self._old_loader = getattr(model, self.stage + "_dataloader")
-        setattr(model, self.stage + "_dataloader", self)
+        It does not check whether ``*_dataloader`` methods are actually overridden.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
 
-    def unpatch(self, model: "pl.LightningModule") -> None:
-        setattr(model, self.stage + "_dataloader", self._old_loader)
-        self._old_loader = None
+        return isinstance(self.instance, (LightningModule, LightningDataModule))
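
For context on how the new `_DataLoaderSource` is meant to be consumed, here is a minimal self-contained sketch of its resolution logic. `DataLoaderSourceSketch` and `FakeDataModule` are hypothetical stand-ins used only for illustration; the real class dispatches on `isinstance(..., LightningModule)` / `isinstance(..., LightningDataModule)` and routes module hooks through `trainer.call_hook`, which the sketch approximates with a plain `hasattr` check.

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader, TensorDataset


class FakeDataModule:
    """Hypothetical stand-in for a LightningDataModule that overrides train_dataloader()."""

    def train_dataloader(self) -> DataLoader:
        return DataLoader(TensorDataset(torch.zeros(4, 1)), batch_size=2)


@dataclass
class DataLoaderSourceSketch:
    """Simplified analogue of _DataLoaderSource: records where dataloaders come from."""

    instance: Optional[Any]  # a module-like object or (a collection of) dataloader(s)
    name: str  # hook name to call when `instance` is a module-like object

    def dataloader(self) -> Any:
        # No hook name: the instance itself is (a collection of) dataloader(s).
        if not self.name:
            return self.instance
        # Module-like source: call the hook with the stored name.
        # (The real class uses isinstance checks and trainer.call_hook instead of hasattr.)
        if hasattr(self.instance, self.name):
            return getattr(self.instance, self.name)()
        # Dataloaders passed directly, e.g. to trainer.fit(): return them unchanged.
        return self.instance


# A dataloader passed directly is returned as-is.
direct = DataLoader(TensorDataset(torch.zeros(4, 1)), batch_size=2)
assert DataLoaderSourceSketch(direct, "train_dataloader").dataloader() is direct

# A module-like source has the hook with the matching name called.
assert isinstance(DataLoaderSourceSketch(FakeDataModule(), "train_dataloader").dataloader(), DataLoader)
```

Recording the source this way, instead of monkey-patching `model.*_dataloader`, leaves the model untouched and removes the `patch`/`unpatch` bookkeeping that `_PatchDataLoader` and `detach_data` needed.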