From e2fbbd0bbcdf67c22ec008316ed2e174aeebecbb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 04:17:09 -0700 Subject: [PATCH 01/22] bugfix-dataloading --- pytorch_lightning/callbacks/progress.py | 17 ++- pytorch_lightning/trainer/data_loading.py | 11 ++ pytorch_lightning/trainer/training_loop.py | 38 ++++-- tests/helpers/boring_model.py | 27 +++- tests/trainer/test_dataloaders.py | 145 +++++++++++++++++---- 5 files changed, 195 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 5a76e5eb97331..b8fa06f649e1a 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -20,6 +20,7 @@ """ import importlib import io +import math import os import sys @@ -402,12 +403,14 @@ def on_train_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - reset(self.main_progress_bar, total_batches) + reset(self.main_progress_bar, convert_inf(total_batches)) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): + total_batches = self.total_train_batches + self.total_val_batches + total_batches = convert_inf(total_batches) + if self._should_update(self.train_batch_idx, total_batches): self._update_bar(self.main_progress_bar) self.main_progress_bar.set_postfix(trainer.progress_bar_dict) @@ -418,11 +421,11 @@ def on_validation_start(self, trainer, pl_module): else: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() - reset(self.val_progress_bar, self.total_val_batches) + reset(self.val_progress_bar, convert_inf(self.total_val_batches)) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.val_batch_idx, self.total_val_batches): + if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): self._update_bar(self.val_progress_bar) self._update_bar(self.main_progress_bar) @@ -479,7 +482,7 @@ def print( s = sep.join(map(str, args)) active_progress_bar.write(s, end=end, file=file, nolock=nolock) - def _should_update(self, current, total): + def _should_update(self, current, total) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def _update_bar(self, bar: Optional[tqdm]) -> None: @@ -496,8 +499,8 @@ def _update_bar(self, bar: Optional[tqdm]) -> None: def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: - """ The tqdm doesn't support inf values. We have to convert it to None. """ - if x == float('inf'): + """ The tqdm doesn't support inf/nan values. We have to convert it to None. 
""" + if x is None or math.isinf(x) or math.isnan(x): return None return x diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 846e267690e94..14064451b8276 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import logging import multiprocessing import os from abc import ABC @@ -262,10 +263,16 @@ def reset_train_dataloader(self, model: LightningModule) -> None: self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') + logging.error( + f"1: self.num_training_batches={self.num_training_batches}, limit_train_batches={self.limit_train_batches}" + ) + if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) + logging.error(f"2: self.num_training_batches={self.num_training_batches}") elif self.num_training_batches != float('inf'): self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) + logging.error(f"3: self.num_training_batches={self.num_training_batches}") elif self.limit_train_batches != 1.0: raise MisconfigurationException( 'When using an IterableDataset for `limit_train_batches`,' @@ -277,7 +284,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # if int passed in, val checks that often # otherwise, it checks in [0, 1.0] % range of a training epoch if isinstance(self.val_check_interval, int): + logging.error(f"4: self.val_check_interval={self.val_check_interval}") self.val_check_batch = self.val_check_interval + logging.error(f"5: self.val_check_batch={self.val_check_batch}") if self.val_check_batch > self.num_training_batches: raise ValueError( f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' @@ -288,6 +297,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: self.val_check_batch = float('inf') + logging.error(f"6: self.val_check_batch={self.val_check_batch}") else: raise MisconfigurationException( 'When using an IterableDataset for `train_dataloader`,' @@ -297,6 +307,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: else: self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) + logging.error(f"7: self.val_check_batch={self.val_check_batch}") def _reset_eval_dataloader( self, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b34452d5cc7eb..a34d775e0b3b7 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -486,7 +486,7 @@ def run_training_epoch(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) + should_check_val = self._should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.validating = True self.trainer.run_evaluation() @@ -535,7 +535,7 @@ def run_training_epoch(self): # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, 
on_epoch=True) + should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval @@ -825,19 +825,31 @@ def should_accumulate(self): is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) - def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): - # decide if we should run validation - is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - can_check_val = self.trainer.enable_validation and is_val_check_epoch - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: + """ Decide if we should run validation. """ + if not self.trainer.enable_validation: + return False + if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + return False + + # val_check_interval is inf for iterable datasets with no length defined + # TODO: the training loop should maintain this logic + # around limit_train_batches and val_check_interval + # not the dataloading mixin + if self.trainer.val_check_batch != float('inf'): + is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + else: + is_val_check_batch = (batch_idx + 1) % self.trainer.num_training_batches == 0 - should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop - or is_last_batch_for_infinite_dataset - ) if on_epoch else (is_val_check_batch and not epoch_end_val_check) + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - return should_check_val and can_check_val + if on_epoch: + return ( + is_val_check_batch and epoch_end_val_check + ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + else: + return is_val_check_batch and not epoch_end_val_check def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 2e7c626306c36..eb81baeb2c29d 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -14,7 +14,7 @@ from typing import Optional import torch -from torch.utils.data import DataLoader, Dataset, Subset +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset from pytorch_lightning import LightningDataModule, LightningModule @@ -60,6 +60,31 @@ def __len__(self): return self.len +class RandomIterableDataset(IterableDataset): + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count + + class BoringModel(LightningModule): def __init__(self): diff --git 
a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 240ddfa37b46e..8360a9fca2f43 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen from tests.helpers.runif import RunIf @@ -234,59 +234,160 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - (0.0, 0.0, 0.0), + (0., 0., 0.), (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite + class DummyModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("loss", self.global_step) + return super().training_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.log("val_log", self.current_epoch) + + import logging + + class EpochCounter(Callback): + + def __init__(self): + super().__init__() + self.train_epoch_count = 0 + self.val_epoch_count = 0 + self.test_epoch_count = 0 + + def on_train_epoch_start(self, trainer, pl_module): + logging.error("on train epoch start") + if not trainer.sanity_checking: + self.train_epoch_count += 1 + + def on_validation_epoch_start(self, trainer, pl_module): + logging.error("on val epoch start") + if not trainer.sanity_checking: + self.val_epoch_count += 1 + + def on_test_epoch_start(self, trainer, pl_module): + logging.error("on test epoch start") + if not trainer.sanity_checking: + self.test_epoch_count += 1 + + ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = EpochCounter() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, + callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) + model = DummyModel() - trainer.fit(model) + batch_size = 8 + train_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + val_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + test_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + + num_batches = 128 / batch_size + for dl in (train_dl, val_dl, test_dl): + if has_len(dl): + assert len(dl) == num_batches + else: + assert sum(1 for _ in dl) == num_batches + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) + assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) + assert epoch_cb.val_epoch_count == int(limit_val_batches > 0) - trainer.test(ckpt_path=None) + 
trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf')) + assert epoch_cb.test_epoch_count == int(limit_test_batches > 0) -@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - (0, 0, 0), - (10, 10, 10), +@pytest.mark.parametrize(['dataset', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (RandomDataset(32, 128), 0, 0, 0), + (RandomDataset(32, 128), 10, 10, 10), + (RandomIterableDataset(32, 128), 0, 0, 0), + (RandomIterableDataset(32, 128), 10, 10, 10), + (RandomIterableDatasetWithLen(32, 128), 0, 0, 0), + (RandomIterableDatasetWithLen(32, 128), 10, 10, 10), ]) -def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): +def test_datasets_dataloaders_with_limit_num_batches( + tmpdir, dataset, limit_train_batches, limit_val_batches, limit_test_batches +): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite + class BasicModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("loss", self.global_step) + return super().training_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.log("val_log", self.current_epoch) + + import logging + + class EpochCounter(Callback): + + def __init__(self): + super().__init__() + self.train_epoch_count = 0 + self.val_epoch_count = 0 + self.test_epoch_count = 0 + + def on_train_epoch_start(self, trainer, pl_module): + logging.error("on train epoch start") + if not trainer.sanity_checking: + self.train_epoch_count += 1 + + def on_validation_epoch_start(self, trainer, pl_module): + logging.error("on val epoch start") + if not trainer.sanity_checking: + self.val_epoch_count += 1 + + def on_test_epoch_start(self, trainer, pl_module): + logging.error("on test epoch start") + if not trainer.sanity_checking: + self.test_epoch_count += 1 + + ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = EpochCounter() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, + callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, - limit_val_batches=limit_val_batches, + limit_val_batches=limit_val_batches, # If added changes error - Checkpoint `MisconfigurationError` if provided. 
limit_test_batches=limit_test_batches, ) + model = BasicModel() - trainer.fit(model) + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + test_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + num_batches = 128 / batch_size + for dl in (train_dl, val_dl, test_dl): + if has_len(dl): + assert len(dl) == num_batches + else: + assert sum(1 for _ in dl) == num_batches + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches + assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) + assert epoch_cb.val_epoch_count == int(limit_val_batches > 0) - trainer.test(ckpt_path=None) + trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == limit_test_batches + assert epoch_cb.test_epoch_count == int(limit_test_batches > 0) @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ From d1fb1731a18c92e2da3da2c6fd163e889ae50e67 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 04:23:03 -0700 Subject: [PATCH 02/22] rm-logs --- pytorch_lightning/trainer/data_loading.py | 11 --- tests/trainer/test_dataloaders.py | 93 +++++++---------------- 2 files changed, 28 insertions(+), 76 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 14064451b8276..846e267690e94 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect -import logging import multiprocessing import os from abc import ABC @@ -263,16 +262,10 @@ def reset_train_dataloader(self, model: LightningModule) -> None: self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') - logging.error( - f"1: self.num_training_batches={self.num_training_batches}, limit_train_batches={self.limit_train_batches}" - ) - if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) - logging.error(f"2: self.num_training_batches={self.num_training_batches}") elif self.num_training_batches != float('inf'): self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) - logging.error(f"3: self.num_training_batches={self.num_training_batches}") elif self.limit_train_batches != 1.0: raise MisconfigurationException( 'When using an IterableDataset for `limit_train_batches`,' @@ -284,9 +277,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # if int passed in, val checks that often # otherwise, it checks in [0, 1.0] % range of a training epoch if isinstance(self.val_check_interval, int): - logging.error(f"4: self.val_check_interval={self.val_check_interval}") self.val_check_batch = self.val_check_interval - logging.error(f"5: self.val_check_batch={self.val_check_batch}") if self.val_check_batch > self.num_training_batches: raise ValueError( f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' @@ -297,7 +288,6 @@ def reset_train_dataloader(self, model: LightningModule) -> None: if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: self.val_check_batch = float('inf') - logging.error(f"6: self.val_check_batch={self.val_check_batch}") else: raise MisconfigurationException( 'When using an IterableDataset for `train_dataloader`,' @@ -307,7 +297,6 @@ def reset_train_dataloader(self, model: LightningModule) -> None: else: self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - logging.error(f"7: self.val_check_batch={self.val_check_batch}") def _reset_eval_dataloader( self, diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 8360a9fca2f43..7d89dc400a221 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -233,45 +233,42 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): assert len(trainer.test_dataloaders) == n -@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - (0., 0., 0.), - (1.0, 1.0, 1.0), -]) -def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): +class DummyModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("loss", self.global_step) + return super().training_step(batch, batch_idx) - class DummyModel(BoringModel): + def validation_epoch_end(self, outputs): + self.log("val_log", self.current_epoch) - def training_step(self, batch, batch_idx): - self.log("loss", self.global_step) - return super().training_step(batch, batch_idx) - def validation_epoch_end(self, outputs): - self.log("val_log", self.current_epoch) +class EpochCounter(Callback): - import logging + def __init__(self): + super().__init__() + self.train_epoch_count = 0 + self.val_epoch_count = 0 + self.test_epoch_count = 0 - class EpochCounter(Callback): + def 
on_train_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking: + self.train_epoch_count += 1 - def __init__(self): - super().__init__() - self.train_epoch_count = 0 - self.val_epoch_count = 0 - self.test_epoch_count = 0 + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking: + self.val_epoch_count += 1 - def on_train_epoch_start(self, trainer, pl_module): - logging.error("on train epoch start") - if not trainer.sanity_checking: - self.train_epoch_count += 1 + def on_test_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking: + self.test_epoch_count += 1 - def on_validation_epoch_start(self, trainer, pl_module): - logging.error("on val epoch start") - if not trainer.sanity_checking: - self.val_epoch_count += 1 - def on_test_epoch_start(self, trainer, pl_module): - logging.error("on test epoch start") - if not trainer.sanity_checking: - self.test_epoch_count += 1 +@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (0., 0., 0.), + (1.0, 1.0, 1.0), +]) +def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = EpochCounter() @@ -321,47 +318,13 @@ def test_datasets_dataloaders_with_limit_num_batches( ): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" - class BasicModel(BoringModel): - - def training_step(self, batch, batch_idx): - self.log("loss", self.global_step) - return super().training_step(batch, batch_idx) - - def validation_epoch_end(self, outputs): - self.log("val_log", self.current_epoch) - - import logging - - class EpochCounter(Callback): - - def __init__(self): - super().__init__() - self.train_epoch_count = 0 - self.val_epoch_count = 0 - self.test_epoch_count = 0 - - def on_train_epoch_start(self, trainer, pl_module): - logging.error("on train epoch start") - if not trainer.sanity_checking: - self.train_epoch_count += 1 - - def on_validation_epoch_start(self, trainer, pl_module): - logging.error("on val epoch start") - if not trainer.sanity_checking: - self.val_epoch_count += 1 - - def on_test_epoch_start(self, trainer, pl_module): - logging.error("on test epoch start") - if not trainer.sanity_checking: - self.test_epoch_count += 1 - ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = EpochCounter() trainer = Trainer( max_epochs=1, callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, - limit_val_batches=limit_val_batches, # If added changes error - Checkpoint `MisconfigurationError` if provided. + limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) model = BasicModel() From 4fb0aa2463e72369b46941be5d59fadf1fff3322 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 04:25:52 -0700 Subject: [PATCH 03/22] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b629e9e72a9a5..a8fa3587894ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed +- Fixed dataloading for iterable datasets used with `limit_train_batches` ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) + + - Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207)) From 8058a3cc37bc92ae9a9587699ee0a50c8ebc99a1 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 04:30:56 -0700 Subject: [PATCH 04/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 7d89dc400a221..6f90828757af8 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -270,7 +270,7 @@ def on_test_epoch_start(self, trainer, pl_module): ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = EpochCounter() trainer = Trainer( max_epochs=1, @@ -318,7 +318,7 @@ def test_datasets_dataloaders_with_limit_num_batches( ): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" - ckpt_callback = ModelCheckpoint(monitor=f"val_log", save_top_k=1, mode="max", verbose=False) + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = EpochCounter() trainer = Trainer( max_epochs=1, @@ -327,7 +327,7 @@ def test_datasets_dataloaders_with_limit_num_batches( limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) - model = BasicModel() + model = DummyModel() batch_size = 8 train_dl = DataLoader(dataset=dataset, batch_size=batch_size) From a5f70c79ae02911c24d17e64707de933bbba300e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 09:05:05 -0700 Subject: [PATCH 05/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 6f90828757af8..3b16dfdd00595 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -334,13 +334,6 @@ def test_datasets_dataloaders_with_limit_num_batches( val_dl = DataLoader(dataset=dataset, batch_size=batch_size) test_dl = DataLoader(dataset=dataset, batch_size=batch_size) - num_batches = 128 / batch_size - for dl in (train_dl, val_dl, test_dl): - if has_len(dl): - assert len(dl) == num_batches - else: - assert sum(1 for _ in dl) == num_batches - trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches From ec93877ecce38082bac86c71d8294bb9331391ee Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 11:14:47 -0700 Subject: [PATCH 06/22] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a34d775e0b3b7..7403d33122148 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -827,8 +827,11 @@ 
def should_accumulate(self): def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: """ Decide if we should run validation. """ + if not self.trainer.enable_validation: return False + + # check if this epoch eligible to run validation if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: return False From f1d9e4d7260409f2f47d92ecc2ecec5cd69e5310 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 11:28:59 -0700 Subject: [PATCH 07/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 3b16dfdd00595..4994de72143a6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -320,8 +320,9 @@ def test_datasets_dataloaders_with_limit_num_batches( ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = EpochCounter() + epochs = 2 trainer = Trainer( - max_epochs=1, + max_epochs=epochs, callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, @@ -338,8 +339,8 @@ def test_datasets_dataloaders_with_limit_num_batches( assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches - assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) - assert epoch_cb.val_epoch_count == int(limit_val_batches > 0) + assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) + assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == limit_test_batches From 069d5b657ec283c2bf330cf96aa418117dd0b2ef Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 12:53:33 -0700 Subject: [PATCH 08/22] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a8fa3587894ec..96791a8fdf69f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed possible NaN errors with progress bars when training with iterable datasets ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) + + - Fixed dataloading for iterable datasets used with `limit_train_batches` ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) From 3df11161404539d6dc092f667626ea2db4ede855 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 12:54:15 -0700 Subject: [PATCH 09/22] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96791a8fdf69f..97b72ea0cfa30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,7 +297,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed -- Fixed possible NaN errors with progress bars when training with iterable datasets ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) +- Fixed NaN errors with progress bars when training with iterable datasets ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) - Fixed dataloading for iterable datasets used with `limit_train_batches` ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) From 05656ef39a8d4b77ff1111a9996897a926421480 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 13:37:39 -0700 Subject: [PATCH 10/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 4994de72143a6..822678e8f3ec7 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -243,18 +243,33 @@ def validation_epoch_end(self, outputs): self.log("val_log", self.current_epoch) -class EpochCounter(Callback): +class Counter(Callback): def __init__(self): super().__init__() self.train_epoch_count = 0 self.val_epoch_count = 0 self.test_epoch_count = 0 + self.train_batches_seen = 0 + self.val_batches_seen = 0 + self.test_batches_seen = 0 + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if not trainer.sanity_checking: + self.train_batches_seen += 1 def on_train_epoch_start(self, trainer, pl_module): if not trainer.sanity_checking: self.train_epoch_count += 1 + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if not trainer.sanity_checking: + self.val_batches_seen += 1 + + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if not trainer.sanity_checking: + self.test_batches_seen += 1 + def on_validation_epoch_start(self, trainer, pl_module): if not trainer.sanity_checking: self.val_epoch_count += 1 @@ -271,7 +286,7 @@ def on_test_epoch_start(self, trainer, pl_module): def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) - epoch_cb = EpochCounter() + epoch_cb = Counter() trainer = Trainer( max_epochs=1, callbacks=[epoch_cb, ckpt_callback], @@ -319,7 +334,7 @@ def test_datasets_dataloaders_with_limit_num_batches( """Verify inf train, val & test dataloaders (e.g. 
IterableDataset) passed with batch limit as number""" ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) - epoch_cb = EpochCounter() + epoch_cb = Counter() epochs = 2 trainer = Trainer( max_epochs=epochs, @@ -340,7 +355,9 @@ def test_datasets_dataloaders_with_limit_num_batches( assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) + assert epoch_cb.train_batches_seen == limit_train_batches * epochs assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) + assert epoch_cb.val_batches_seen == limit_val_batches * epochs trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == limit_test_batches From 0044e0d1c7412a519de5b8a892c6e1a9b2918a48 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 15:32:51 -0700 Subject: [PATCH 11/22] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7403d33122148..3c3d8f17c6a86 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -455,6 +455,11 @@ def run_training_epoch(self): is_last_batch = None for batch_idx, (batch, is_last_batch) in train_dataloader: + + if isinstance(self.trainer.limit_train_batches, int): + limit_reached = (batch_idx + 1) % self.trainer.limit_train_batches == 0 + is_last_batch = is_last_batch or limit_reached + self.trainer.batch_idx = batch_idx self.trainer.is_last_batch = is_last_batch @@ -484,7 +489,7 @@ def run_training_epoch(self): self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- - # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # VALIDATE IF NEEDED # ----------------------------------------- should_check_val = self._should_check_val_fx(batch_idx, is_last_batch) if should_check_val: @@ -831,7 +836,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo if not self.trainer.enable_validation: return False - # check if this epoch eligible to run validation + # check if this epoch is eligible to run validation if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: return False @@ -839,10 +844,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo # TODO: the training loop should maintain this logic # around limit_train_batches and val_check_interval # not the dataloading mixin - if self.trainer.val_check_batch != float('inf'): + is_val_check_batch = False + if isinstance(self.trainer.limit_train_batches, int): + is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 + elif self.trainer.val_check_batch != float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - else: - is_val_check_batch = (batch_idx + 1) % self.trainer.num_training_batches == 0 epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") From 15e52be72c665362d058a42477ca91893b09b55a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 1 May 2021 15:46:51 -0700 Subject: [PATCH 12/22] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 7 +------ 1 file 
changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3c3d8f17c6a86..43a475d6e727c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -455,11 +455,6 @@ def run_training_epoch(self): is_last_batch = None for batch_idx, (batch, is_last_batch) in train_dataloader: - - if isinstance(self.trainer.limit_train_batches, int): - limit_reached = (batch_idx + 1) % self.trainer.limit_train_batches == 0 - is_last_batch = is_last_batch or limit_reached - self.trainer.batch_idx = batch_idx self.trainer.is_last_batch = is_last_batch @@ -845,7 +840,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo # around limit_train_batches and val_check_interval # not the dataloading mixin is_val_check_batch = False - if isinstance(self.trainer.limit_train_batches, int): + if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 From c95f9968c3c374bcca57ef533cc8cc32a27eb497 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 11:03:09 -0700 Subject: [PATCH 13/22] comments --- CHANGELOG.md | 4 ++-- pytorch_lightning/trainer/training_loop.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97b72ea0cfa30..6573a11692eb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,10 +297,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed NaN errors with progress bars when training with iterable datasets ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) +- Fixed NaN errors in progress bars when training with iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) -- Fixed dataloading for iterable datasets used with `limit_train_batches` ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) +- Fixed validation being skipped for iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) - Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207)) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 43a475d6e727c..f2837063e0355 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -835,16 +835,17 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: return False - # val_check_interval is inf for iterable datasets with no length defined + # val_check_batch is inf for iterable datasets with no length defined # TODO: the training loop should maintain this logic - # around limit_train_batches and val_check_interval - # not the dataloading mixin + # around limit_train_batches and val_check_batch + # not the trainer dataloading mixin is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): is_val_check_batch = 
(batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + # Note: num_training_batches is also inf for iterable datasets with no length defined epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") From 01dca9878d02fcfc62e0f2f07e9d2ad3351a35ba Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 12:41:19 -0700 Subject: [PATCH 14/22] address comments --- pytorch_lightning/callbacks/progress.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 4 +--- tests/trainer/test_dataloaders.py | 23 +++++++++++----------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index b8fa06f649e1a..34f644f15eaea 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -403,7 +403,7 @@ def on_train_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - reset(self.main_progress_bar, convert_inf(total_batches)) + reset(self.main_progress_bar, total_batches) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -421,7 +421,7 @@ def on_validation_start(self, trainer, pl_module): else: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() - reset(self.val_progress_bar, convert_inf(self.total_val_batches)) + reset(self.val_progress_bar, self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f2837063e0355..f96c17a0686ce 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -836,9 +836,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo return False # val_check_batch is inf for iterable datasets with no length defined - # TODO: the training loop should maintain this logic - # around limit_train_batches and val_check_batch - # not the trainer dataloading mixin + # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 822678e8f3ec7..f242c52dba6b6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -255,28 +255,22 @@ def __init__(self): self.test_batches_seen = 0 def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - if not trainer.sanity_checking: - self.train_batches_seen += 1 + self.train_batches_seen += 1 def on_train_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking: - self.train_epoch_count += 1 + 
self.train_epoch_count += 1 def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - if not trainer.sanity_checking: - self.val_batches_seen += 1 + self.val_batches_seen += 1 def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - if not trainer.sanity_checking: - self.test_batches_seen += 1 + self.test_batches_seen += 1 def on_validation_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking: - self.val_epoch_count += 1 + self.val_epoch_count += 1 def on_test_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking: - self.test_epoch_count += 1 + self.test_epoch_count += 1 @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ @@ -284,10 +278,13 @@ def on_test_epoch_start(self, trainer, pl_module): (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = Counter() trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, max_epochs=1, callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, @@ -337,6 +334,8 @@ def test_datasets_dataloaders_with_limit_num_batches( epoch_cb = Counter() epochs = 2 trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, max_epochs=epochs, callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, From 89d284a56eacf138451285894ac824e30221229c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 13:59:32 -0700 Subject: [PATCH 15/22] more tests --- pytorch_lightning/callbacks/progress.py | 3 +- tests/trainer/test_dataloaders.py | 70 +++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 34f644f15eaea..e9088f9b61c5b 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -20,6 +20,7 @@ """ import importlib import io +import logging import math import os import sys @@ -398,7 +399,7 @@ def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches - if total_train_batches != float('inf'): + if total_train_batches != float('inf') and total_val_batches != float('inf'): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f242c52dba6b6..57fa0654968e3 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -317,6 +317,76 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, assert epoch_cb.test_epoch_count == int(limit_test_batches > 0) +@pytest.mark.parametrize(['dataset', 'limit_train_batches'], [ + (RandomDataset(32, 128), 0), + (RandomDataset(32, 128), 10), + (RandomIterableDataset(32, 128), 0), + (RandomIterableDataset(32, 128), 10), + (RandomIterableDatasetWithLen(32, 128), 0), + (RandomIterableDatasetWithLen(32, 128), 10), +]) +def test_dataloaders_with_limit_train_batches(tmpdir, dataset, 
limit_train_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = Counter() + epochs = 2 + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + max_epochs=epochs, + callbacks=[epoch_cb, ckpt_callback], + limit_train_batches=limit_train_batches, + ) + model = DummyModel() + + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == limit_train_batches + assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) + assert epoch_cb.train_batches_seen == limit_train_batches * epochs + + +@pytest.mark.parametrize(['dataset', 'limit_val_batches'], [ + (RandomDataset(32, 128), 0), + (RandomDataset(32, 128), 10), + (RandomIterableDataset(32, 128), 0), + (RandomIterableDataset(32, 128), 10), + (RandomIterableDatasetWithLen(32, 128), 0), + (RandomIterableDatasetWithLen(32, 128), 10), +]) +def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + + epoch_cb = Counter() + callbacks = [epoch_cb] + if limit_val_batches > 0: + callbacks.append(ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)) + epochs = 2 + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + max_epochs=epochs, + callbacks=callbacks, + limit_val_batches=limit_val_batches, + ) + model = DummyModel() + + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_val_batches[0] == limit_val_batches + assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) + assert epoch_cb.val_batches_seen == limit_val_batches * epochs + + @pytest.mark.parametrize(['dataset', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ (RandomDataset(32, 128), 0, 0, 0), (RandomDataset(32, 128), 10, 10, 10), From 67bc801eea442d077291e7dedd648a506d6b77e6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 14:05:33 -0700 Subject: [PATCH 16/22] Update progress.py --- pytorch_lightning/callbacks/progress.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index e9088f9b61c5b..be9d2f44356f5 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -20,7 +20,6 @@ """ import importlib import io -import logging import math import os import sys From 62b546ace59a6469067498527e693313f5180bfb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 14:14:02 -0700 Subject: [PATCH 17/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 57fa0654968e3..ae444ca1b9b92 100644 --- a/tests/trainer/test_dataloaders.py +++ 
b/tests/trainer/test_dataloaders.py @@ -274,7 +274,7 @@ def on_test_epoch_start(self, trainer, pl_module): @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - (0., 0., 0.), + (0.0, 0.0, 0.0), (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): From cc1ba4a3fbd6f4b75f70201850a4ac09abe5916e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 14:46:32 -0700 Subject: [PATCH 18/22] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ae444ca1b9b92..4774380e9d4f8 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -364,8 +364,11 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): epoch_cb = Counter() callbacks = [epoch_cb] + checkpoint_callback = True if limit_val_batches > 0: callbacks.append(ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)) + else: + checkpoint_callback = False epochs = 2 trainer = Trainer( default_root_dir=tmpdir, @@ -373,6 +376,7 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): max_epochs=epochs, callbacks=callbacks, limit_val_batches=limit_val_batches, + checkpoint_callback=checkpoint_callback, ) model = DummyModel() From a3f959c499f7ed0fadeb3d175ee8f5b6b50d9c0a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 15:03:07 -0700 Subject: [PATCH 19/22] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f96c17a0686ce..3883bc3706682 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -97,12 +97,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") From 49cd18c0239765e3ac14a54f92f7efdddbb1c004 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 15:03:42 -0700 Subject: [PATCH 20/22] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3883bc3706682..f96c17a0686ce 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -97,6 +97,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") From 5cd2482f7cd28677df1fe28bb4fe1709f45b010c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 15:41:42 -0700 Subject: [PATCH 21/22] test ckpt fix? 
--- .../callbacks/model_checkpoint.py | 33 +++++++++++++++++-- pytorch_lightning/trainer/training_loop.py | 6 ---- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ac73efd67cb78..f63fba7151f55 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -188,6 +188,7 @@ def __init__( auto_insert_metric_name: bool = True, every_n_train_steps: Optional[int] = None, every_n_val_epochs: Optional[int] = None, + save_on_train_end: bool = False, period: Optional[int] = None, ): super().__init__() @@ -204,6 +205,7 @@ def __init__( self.best_model_score = None self.best_model_path = "" self.last_model_path = "" + self._save_on_train_end = save_on_train_end self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) @@ -231,9 +233,7 @@ def on_train_batch_end( self.save_checkpoint(trainer) def on_validation_end(self, trainer, pl_module) -> None: - """ - checkpoints can be saved at the end of the val loop - """ + """ Save a checkpoint at the end of the validation stage. """ skip = ( self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 @@ -242,6 +242,33 @@ def on_validation_end(self, trainer, pl_module) -> None: return self.save_checkpoint(trainer) + def on_train_end(self, trainer, pl_module) -> None: + """Save a checkpoint at the very end of training. + + This will only save a checkpoint if `save_last` is also enabled + as the monitor metrics produced by training or validation steps or end of epochs + is not guaranteed to be available at this stage. + """ + if self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained: + return + + initial_save_last = self.save_last + if self._save_on_train_end and not self.save_last: + rank_zero_warn( + "Requested to save a checkpoint at the end of training but save_last is not set. Temporarily setting save_last=True to save." + ) + self.save_last = True + if self.verbose: + rank_zero_info("Saving last checkpoint...") + + # as we advance one step at end of training, we use global_step - 1 + # to avoid saving duplicates + trainer.global_step -= 1 + monitor_candidates = self._monitor_candidates(trainer) + self._save_last_checkpoint(trainer, monitor_candidates) + trainer.global_step += 1 + self.save_last = initial_save_last + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f96c17a0686ce..3883bc3706682 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -97,12 +97,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") From 0e7b0ad9cdb249988d163764e074f7cb5f26ca88 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 16:00:08 -0700 Subject: [PATCH 22/22] update again --- .../callbacks/model_checkpoint.py | 29 ------------------- pytorch_lightning/trainer/training_loop.py | 6 ++++ tests/trainer/test_dataloaders.py | 20 ++++++++----- 3 files changed, 18 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f63fba7151f55..82839007b6851 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -188,7 +188,6 @@ def __init__( auto_insert_metric_name: bool = True, every_n_train_steps: Optional[int] = None, every_n_val_epochs: Optional[int] = None, - save_on_train_end: bool = False, period: Optional[int] = None, ): super().__init__() @@ -205,7 +204,6 @@ def __init__( self.best_model_score = None self.best_model_path = "" self.last_model_path = "" - self._save_on_train_end = save_on_train_end self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) @@ -242,33 +240,6 @@ def on_validation_end(self, trainer, pl_module) -> None: return self.save_checkpoint(trainer) - def on_train_end(self, trainer, pl_module) -> None: - """Save a checkpoint at the very end of training. - - This will only save a checkpoint if `save_last` is also enabled - as the monitor metrics produced by training or validation steps or end of epochs - is not guaranteed to be available at this stage. - """ - if self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained: - return - - initial_save_last = self.save_last - if self._save_on_train_end and not self.save_last: - rank_zero_warn( - "Requested to save a checkpoint at the end of training but save_last is not set. Temporarily setting save_last=True to save." - ) - self.save_last = True - if self.verbose: - rank_zero_info("Saving last checkpoint...") - - # as we advance one step at end of training, we use global_step - 1 - # to avoid saving duplicates - trainer.global_step -= 1 - monitor_candidates = self._monitor_candidates(trainer) - self._save_last_checkpoint(trainer, monitor_candidates) - trainer.global_step += 1 - self.save_last = initial_save_last - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3883bc3706682..f96c17a0686ce 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -97,6 +97,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 4774380e9d4f8..6b0ea97d41a70 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -351,14 +351,18 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch assert epoch_cb.train_batches_seen == limit_train_batches * epochs -@pytest.mark.parametrize(['dataset', 'limit_val_batches'], [ - (RandomDataset(32, 128), 0), - (RandomDataset(32, 128), 10), - (RandomIterableDataset(32, 128), 0), - (RandomIterableDataset(32, 128), 10), - (RandomIterableDatasetWithLen(32, 128), 0), - (RandomIterableDatasetWithLen(32, 128), 10), -]) +@pytest.mark.parametrize( + ['dataset', 'limit_val_batches'], + [ + (RandomDataset(32, 128), 0), + (RandomDataset(32, 128), 10), + (RandomIterableDataset(32, 128), 0), + (RandomIterableDataset(32, 128), 10), + (RandomIterableDatasetWithLen(32, 128), 0), + # TODO: enable this after #6671 is merged + # (RandomIterableDatasetWithLen(32, 128), 10), + ] +) def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
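
Editor's note for reviewers: the behavioural change this series converges on can be illustrated outside of Lightning itself. What follows is a minimal, self-contained sketch in plain Python — not the actual ProgressBar/TrainLoop classes. convert_inf mirrors the final helper in pytorch_lightning/callbacks/progress.py, while should_check_val condenses the decision implemented by _should_check_val_fx as of patch 13; the relevant trainer attributes are passed in as ordinary arguments purely to keep the example runnable (an assumption of this sketch), and the enable_validation / check_val_every_n_epoch gating is omitted.

    import math
    from typing import Optional, Union


    def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
        """tqdm cannot render inf/nan totals, so map them (and None) to None."""
        if x is None or math.isinf(x) or math.isnan(x):
            return None
        return x


    def should_check_val(
        batch_idx: int,
        is_last_batch: bool,
        val_check_batch: float,
        limit_train_batches: Union[int, float],
        num_training_batches: float,
        on_epoch: bool = False,
        should_stop: bool = False,
    ) -> bool:
        """Condensed form of ``_should_check_val_fx`` after patch 13."""
        # val_check_batch is inf for iterable datasets with no length defined,
        # so fall back to the integer limit_train_batches to find the check point
        is_val_check_batch = False
        if isinstance(limit_train_batches, int) and val_check_batch == float("inf"):
            is_val_check_batch = (batch_idx + 1) % limit_train_batches == 0
        elif val_check_batch != float("inf"):
            is_val_check_batch = (batch_idx + 1) % val_check_batch == 0

        # num_training_batches is also inf for unsized, unlimited iterable datasets
        epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0
        last_batch_of_infinite_dataset = is_last_batch and val_check_batch == float("inf")

        if on_epoch:
            return (
                (is_val_check_batch and epoch_end_val_check)
                or should_stop
                or last_batch_of_infinite_dataset
            )
        return is_val_check_batch and not epoch_end_val_check


    # The fix in a nutshell: an unsized IterableDataset with limit_train_batches=10
    # leaves reset_train_dataloader with val_check_batch=inf and
    # num_training_batches=10. At the end-of-epoch check for batch index 9 the
    # dataset itself is not exhausted (is_last_batch=False), yet validation now runs.
    assert should_check_val(
        batch_idx=9,
        is_last_batch=False,
        val_check_batch=float("inf"),
        limit_train_batches=10,
        num_training_batches=10,
        on_epoch=True,
    )
    # Progress bars receive None instead of inf/nan totals, avoiding tqdm errors.
    assert convert_inf(float("inf")) is None

As the TODO in the training loop notes, this sketch also reflects the intended direction of the series: the training/eval loop, not the dataloading mixin, should own the interaction between limit_*_batches and val_check_batch.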