@@ -47,6 +47,7 @@ class TrainerDataLoadingMixin(ABC):
4747    # this is just a summary on variables used in this abstract class, 
4848    #  the proper values/initialisation should be done in child class 
4949    val_check_interval : float 
50+     reload_dataloaders_every_n_epochs : int 
5051    tpu_local_core_rank : int 
5152    train_dataloader : DataLoader 
5253    limit_train_batches : Union [int , float ]
@@ -67,6 +68,9 @@ class TrainerDataLoadingMixin(ABC):
6768    accelerator : Accelerator 
6869    call_hook : Callable 
6970    _accelerator_connector : AcceleratorConnector 
71+     current_epoch : int 
72+     _last_train_dl_reload_epoch : int 
73+     _last_val_dl_reload_epoch : int 
7074
7175    def  _worker_check (self , dataloader : DataLoader , name : str ) ->  None :
7276        if  not  isinstance (dataloader , DataLoader ):
@@ -277,6 +281,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
277281                category = PossibleUserWarning ,
278282            )
279283
284+         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs 
285+         self ._last_train_dl_reload_epoch  =  self .current_epoch 
286+ 
280287    def  _reset_eval_dataloader (
281288        self , mode : RunningStage , model : Optional ["pl.LightningModule" ] =  None 
282289    ) ->  Tuple [List [Union [int , float ]], List [DataLoader ]]:
@@ -372,6 +379,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
372379                RunningStage .VALIDATING , model = pl_module 
373380            )
374381
382+             # store epoch of dataloader reset for reload_dataloaders_every_n_epochs 
383+             self ._last_val_dl_reload_epoch  =  self .current_epoch 
384+ 
375385    def  reset_test_dataloader (self , model : Optional ["pl.LightningModule" ] =  None ) ->  None :
376386        """Resets the test dataloader and determines the number of batches. 
377387
@@ -459,3 +469,15 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
459469            dataloader  =  apply_to_collection (dataloader , DataLoader , replace_sampler )
460470
461471        return  dataloader 
472+ 
473+     @property  
474+     def  _should_reload_train_dl (self ) ->  bool :
475+         """Check if train dataloader should be reloaded.""" 
476+         n_epochs  =  self .reload_dataloaders_every_n_epochs 
477+         return  n_epochs  and  self .current_epoch  -  self ._last_train_dl_reload_epoch  >=  n_epochs 
478+ 
479+     @property  
480+     def  _should_reload_val_dl (self ) ->  bool :
481+         """Check if validation dataloader should be reloaded.""" 
482+         n_epochs  =  self .reload_dataloaders_every_n_epochs 
483+         return  n_epochs  and  self .current_epoch  -  self ._last_val_dl_reload_epoch  >=  n_epochs 
0 commit comments