Skip to content

Commit 0300720

Browse files
committed
Moved changes to TrainerDataLoadingMixing
1 parent 67b637a commit 0300720

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class TrainerDataLoadingMixin(ABC):
4747
# this is just a summary on variables used in this abstract class,
4848
# the proper values/initialisation should be done in child class
4949
val_check_interval: float
50+
reload_dataloaders_every_n_epochs: int
5051
tpu_local_core_rank: int
5152
train_dataloader: DataLoader
5253
limit_train_batches: Union[int, float]
@@ -67,6 +68,9 @@ class TrainerDataLoadingMixin(ABC):
6768
accelerator: Accelerator
6869
call_hook: Callable
6970
_accelerator_connector: AcceleratorConnector
71+
current_epoch: int
72+
_last_train_dl_reload_epoch: int
73+
_last_val_dl_reload_epoch: int
7074

7175
def _worker_check(self, dataloader: DataLoader, name: str) -> None:
7276
if not isinstance(dataloader, DataLoader):
@@ -277,6 +281,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
277281
category=PossibleUserWarning,
278282
)
279283

284+
# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
285+
self._last_train_dl_reload_epoch = self.current_epoch
286+
280287
def _reset_eval_dataloader(
281288
self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
282289
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
@@ -372,6 +379,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
372379
RunningStage.VALIDATING, model=pl_module
373380
)
374381

382+
# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
383+
self._last_val_dl_reload_epoch = self.current_epoch
384+
375385
def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
376386
"""Resets the test dataloader and determines the number of batches.
377387
@@ -459,3 +469,15 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
459469
dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)
460470

461471
return dataloader
472+
473+
@property
474+
def _should_reload_train_dl(self) -> bool:
475+
"""Check if train dataloader should be reloaded."""
476+
n_epochs = self.reload_dataloaders_every_n_epochs
477+
return n_epochs and self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs
478+
479+
@property
480+
def _should_reload_val_dl(self) -> bool:
481+
"""Check if validation dataloader should be reloaded."""
482+
n_epochs = self.reload_dataloaders_every_n_epochs
483+
return n_epochs and self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,14 +1358,6 @@ def _run_sanity_check(self, ref_model):
13581358
# restore the previous stage when the sanity check if finished
13591359
self.state.stage = stage
13601360

1361-
def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
1362-
self._last_train_dl_reload_epoch = self.current_epoch
1363-
return super().reset_train_dataloader(model=model)
1364-
1365-
def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
1366-
self._last_val_dl_reload_epoch = self.current_epoch
1367-
return super().reset_val_dataloader(model=model)
1368-
13691361
def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
13701362
if model_provided and ckpt_path is None:
13711363
# use passed model to function without loading weights
@@ -1850,18 +1842,6 @@ def progress_bar_dict(self) -> dict:
18501842
return self.progress_bar_callback.get_metrics(self, ref_model)
18511843
return self.progress_bar_metrics
18521844

1853-
@property
1854-
def _should_reload_train_dl(self) -> bool:
1855-
"""Check if train dataloader should be reloaded."""
1856-
n_epochs = self.reload_dataloaders_every_n_epochs
1857-
return n_epochs and self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs
1858-
1859-
@property
1860-
def _should_reload_val_dl(self) -> bool:
1861-
"""Check if validation dataloader should be reloaded."""
1862-
n_epochs = self.reload_dataloaders_every_n_epochs
1863-
return n_epochs and self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs
1864-
18651845
@property
18661846
def enable_validation(self) -> bool:
18671847
"""Check if we should run validation during training."""

0 commit comments

Comments
 (0)