From b05b74d458ff86587deb1381e5258a7f8d8642ab Mon Sep 17 00:00:00 2001
From: Nik Vaessen
Date: Fri, 18 Feb 2022 18:27:27 +0100
Subject: [PATCH 01/13] support val_check_interval values higher than number of training batches

---
 CHANGELOG.md                                  |  3 +
 pytorch_lightning/callbacks/progress/base.py  |  6 +-
 .../loops/epoch/training_epoch_loop.py        | 11 ++-
 .../trainer/connectors/data_connector.py      |  4 +-
 pytorch_lightning/trainer/trainer.py          | 17 ++--
 tests/loops/test_training_loop.py             | 98 ++++++++++++++++++-
 6 files changed, 127 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 267895c407f25..acd1da7fd7d38 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Support setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#8135](https://github.com/PyTorchLightning/pytorch-lightning/issues/8135))
+
+
 - Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
 
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 291fb495a81c9..9dbd4015da36d 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -110,7 +110,11 @@ def total_val_batches(self) -> Union[int, float]:
         total_val_batches = 0
         if self.trainer.enable_validation:
-            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+            is_val_epoch = (
+                True
+                if self.trainer.check_val_every_n_epoch is None
+                else (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+            )
             total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
 
         return total_val_batches
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index c8eefedd3c327..258cba7166715 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -480,7 +480,10 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if not self.trainer.enable_validation:
             return False
 
-        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        is_val_check_epoch = (
+            self.trainer.check_val_every_n_epoch is None
+            or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        )
         if not is_val_check_epoch:
             return False
 
@@ -492,7 +495,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if self.trainer.should_stop:
             return True
 
-        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
+        if self.trainer.check_val_every_n_epoch is None:
+            return (self.trainer.global_step + 1) % self.trainer.val_check_batch == 0
+
+        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_ba
+        # tches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index ef79bd88db822..0bb18d17f4778 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -66,8 +66,8 @@ def _should_reload_val_dl(self) -> bool:
 
     def on_trainer_init(
         self,
-        check_val_every_n_epoch: int,
         reload_dataloaders_every_n_epochs: int,
+        check_val_every_n_epoch: Optional[int] = None,
         prepare_data_per_node: Optional[bool] = None,
     ) -> None:
         self.trainer.datamodule = None
@@ -80,7 +80,7 @@ def on_trainer_init(
             )
         self.trainer.prepare_data_per_node = prepare_data_per_node
 
-        if not isinstance(check_val_every_n_epoch, int):
+        if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
             raise MisconfigurationException(
                 f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
             )
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6ed5d6c31f719..93d4cb76d32c6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -149,7 +149,7 @@ def __init__(
         enable_progress_bar: bool = True,
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
-        check_val_every_n_epoch: int = 1,
+        check_val_every_n_epoch: Optional[int] = 1,
         fast_dev_run: Union[int, bool] = False,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
@@ -239,7 +239,8 @@ def __init__(
             It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
             :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
 
-        check_val_every_n_epoch: Check val every n train epochs.
+        check_val_every_n_epoch: Check val every n train epochs. If `None`, validation will be done based on
+            `val_check_interval`, and potentially exceed the number of batches in the training set.
 
         default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
             Default: ``os.getcwd()``.
@@ -391,6 +392,8 @@ def __init__(
 
         val_check_interval: How often to check the validation set. Use float to check within a training epoch, use
            int to check every n steps (batches).
+            This value can only be higher than the amount of batches in the data loader when
+            `check_val_every_n_epoch=None`, otherwise no validation is done.
 
         enable_model_summary: Whether to enable model summarization by default.
 
@@ -513,8 +516,8 @@ def __init__(
 
         # init data flags
         self._data_connector.on_trainer_init(
-            check_val_every_n_epoch,
             reload_dataloaders_every_n_epochs,
+            check_val_every_n_epoch,
            prepare_data_per_node,
         )
@@ -1133,7 +1136,7 @@ def _run(
         # ----------------------------
         # INSPECT THE CORE LOOPS
         # ----------------------------
-        fr"""
+        rf"""
         Lightning internal flow looks like this:
             {Trainer.fit} or {Trainer.test} or {Trainer.predict}  ||
                                 |                                 ||
@@ -1822,11 +1825,13 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
         # otherwise, it checks in [0, 1.0] % range of a training epoch
         if isinstance(self.val_check_interval, int):
             self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
+            if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                     f"to the number of the training batches ({self.num_training_batches}). "
-                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                    "If you want to disable validation set `limit_val_batches` to 0.0 instead. "
+                    "If you want to validate based on the step count instead of the epoch count, "
+                    "set `check_val_every_n_epoch=None`."
) else: if not has_len_all_ranks(self.train_dataloader, self.strategy, module): diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 8329389a93944..fcc30e930ed76 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_outputs_format(tmpdir): @@ -142,3 +143,98 @@ def training_step_end(self, outputs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) trainer.fit(model) + + +def test_validation_check_interval_exceed_data_length_correct(tmpdir): + batch_size = 32 + data_samples_train = 10 + data_samples_val = 1 + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.validation_called_at_step = set() + + def training_step(self, batch, batch_idx): + return super().training_step(batch, batch_idx) + + def validation_step(self, *args): + self.validation_called_at_step.add(int(self.trainer.global_step)) + return super().validation_step(*args) + + def train_dataloader(self): + return DataLoader(RandomDataset(batch_size, data_samples_train)) + + def val_dataloader(self): + return DataLoader(RandomDataset(batch_size, data_samples_val)) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=data_samples_train * 3, + val_check_interval=15, + check_val_every_n_epoch=None, + num_sanity_val_steps=0, + ) + + trainer.fit(model) + + # with a data length of 10, a val_check_interval of 15, and max_steps=30, we + # should have validated twice + assert trainer.current_epoch == 3 + assert trainer.global_step == 30 + assert list(model.validation_called_at_step) == [14, 29] + + +def test_validation_check_interval_exceed_data_length_wrong(tmpdir): + model = BoringModel() + + with pytest.raises(ValueError): + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=200, + val_check_interval=100, + check_val_every_n_epoch=1, + num_sanity_val_steps=0, + ) + trainer.fit(model) + + +def test_validation_loop_every_5_epochs(tmpdir): + batch_size = 32 + data_samples_train = 10 + data_samples_val = 1 + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.validation_called_at_step = set() + + def training_step(self, batch, batch_idx): + return super().training_step(batch, batch_idx) + + def validation_step(self, *args): + self.validation_called_at_step.add(int(self.trainer.global_step)) + return super().validation_step(*args) + + def train_dataloader(self): + return DataLoader(RandomDataset(batch_size, data_samples_train)) + + def val_dataloader(self): + return DataLoader(RandomDataset(batch_size, data_samples_val)) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=data_samples_train * 9, + check_val_every_n_epoch=5, + num_sanity_val_steps=0, + ) + + trainer.fit(model) + + # with a data length of 10, a val_check_interval of 15, and max_steps=30, we + # should have validated twice + assert trainer.current_epoch == 9 + assert trainer.global_step == 90 + assert list(model.validation_called_at_step) == [49] From 7039873058e1c846eea22fe87141f4d7098aa28a Mon Sep 17 00:00:00 2001 From: Nik Vaessen Date: Fri, 18 Feb 2022 19:30:58 +0100 Subject: [PATCH 02/13] tests pass locally --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 +-- 
 pytorch_lightning/trainer/trainer.py                 | 4 ++--
 tests/loops/test_training_loop.py                    | 2 +-
 tests/trainer/test_supporters.py                     | 3 +++
 tests/trainer/test_trainer.py                        | 2 +-
 5 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 258cba7166715..22b1d0e1fc327 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -498,8 +498,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if self.trainer.check_val_every_n_epoch is None:
             return (self.trainer.global_step + 1) % self.trainer.val_check_batch == 0
 
-        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_ba
-        # tches and val_check_batch
+        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 93d4cb76d32c6..f34836d58adf6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -239,8 +239,8 @@ def __init__(
             It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
             :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
 
-        check_val_every_n_epoch: Check val every n train epochs. If `None`, validation will be done based on
-            `val_check_interval`, and potentially exceed the number of batches in the training set.
+        check_val_every_n_epoch: Check val every n train epochs. If `None`, validation will be done solely based
+            on the number of steps, requiring `val_check_interval` to be an integer value.
 
         default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
             Default: ``os.getcwd()``.
diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index fcc30e930ed76..71ad7b1e33334 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -183,7 +183,7 @@ def val_dataloader(self): # should have validated twice assert trainer.current_epoch == 3 assert trainer.global_step == 30 - assert list(model.validation_called_at_step) == [14, 29] + assert sorted(list(model.validation_called_at_step)) == [14, 29] def test_validation_check_interval_exceed_data_length_wrong(tmpdir): diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 7088432e3b9ec..139a54ff3bcd8 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -36,6 +36,7 @@ from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.runif import RunIf def test_tensor_running_accum_reset(): @@ -381,6 +382,7 @@ def _assert_dataset(loader): apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) +@RunIf(min_gpus=2) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp): """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader @@ -434,6 +436,7 @@ def __iter__(self): assert get_len(dataloader) == float("inf") +@RunIf(min_gpus=2) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) @pytest.mark.parametrize("is_min_size_mode", [False, True]) @pytest.mark.parametrize("use_combined_loader", [False, True]) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0d2d8bbdc55b6..69228005e5d09 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1512,7 +1512,7 @@ def test_trainer_predict_1_gpu(tmpdir): predict(tmpdir, accelerator="gpu", devices=1) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, min_gpus=2) def test_trainer_predict_ddp_spawn(tmpdir): predict(tmpdir, strategy="ddp_spawn", accelerator="auto", devices=2) From 9b332dcbb11ddd1bc103aa15b750198a640a4416 Mon Sep 17 00:00:00 2001 From: Nik Vaessen Date: Fri, 18 Feb 2022 19:58:16 +0100 Subject: [PATCH 03/13] add a test case for when check_val_interval is float --- .../trainer/connectors/data_connector.py | 7 +++++++ pytorch_lightning/trainer/trainer.py | 1 + tests/loops/test_training_loop.py | 15 +++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 0bb18d17f4778..fc1d6b7555dce 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -67,6 +67,7 @@ def _should_reload_val_dl(self) -> bool: def on_trainer_init( self, reload_dataloaders_every_n_epochs: int, + val_check_interval: Union[int, float], check_val_every_n_epoch: Optional[int] = None, prepare_data_per_node: Optional[bool] = None, ) -> None: @@ -85,6 +86,12 @@ def on_trainer_init( f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}" ) + if check_val_every_n_epoch is None and isinstance(val_check_interval, float): + raise MisconfigurationException( + f"val_check_interval should be an integer when check_val_every_n_epoch={check_val_every_n_epoch}. 
" + f"Found val_check_interval={val_check_interval}" + ) + self.trainer.check_val_every_n_epoch = check_val_every_n_epoch if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f34836d58adf6..7753e86aa220c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -517,6 +517,7 @@ def __init__( # init data flags self._data_connector.on_trainer_init( reload_dataloaders_every_n_epochs, + val_check_interval, check_val_every_n_epoch, prepare_data_per_node, ) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 71ad7b1e33334..a602d03614ce1 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -200,6 +201,20 @@ def test_validation_check_interval_exceed_data_length_wrong(tmpdir): trainer.fit(model) +def test_validation_check_interval_float_wrong(tmpdir): + model = BoringModel() + + with pytest.raises(MisconfigurationException): + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=200, + val_check_interval=0.5, + check_val_every_n_epoch=None, + num_sanity_val_steps=0, + ) + trainer.fit(model) + + def test_validation_loop_every_5_epochs(tmpdir): batch_size = 32 data_samples_train = 10 From 8bf975a019fd2034d8e021031259d4f898f05c7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Feb 2022 12:13:47 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c3ecf21b41d50..99db5134fb6d7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -248,7 +248,7 @@ def __init__( Default: ``True``. check_val_every_n_epoch: Check val every n train epochs. If `None`, validation will be done solely based - on the number of steps, requiring `val_check_interval` to be an integer value. + on the number of steps, requiring `val_check_interval` to be an integer value. Default: ``1``. default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. From 538be35645951ee104d5730ca47d2eb1fd00d1ee Mon Sep 17 00:00:00 2001 From: Nik Vaessen Date: Wed, 6 Apr 2022 11:37:39 +0200 Subject: [PATCH 05/13] Fix error in merging CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18773089b458c..bfe4a311f9201 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 ### Added
 
+- Allow logging to an existing run ID in MLflow with `MLFlowLogger` ([#12290](https://github.com/PyTorchLightning/pytorch-lightning/pull/12290))
 - Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
 - Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs ([#11008](https://github.com/PyTorchLightning/pytorch-lightning/pull/11008))
 - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/pull/10601))

From e0c568d91b31d0f3fdc402274064156449055c6e Mon Sep 17 00:00:00 2001
From: Nik Vaessen
Date: Wed, 6 Apr 2022 12:32:11 +0200
Subject: [PATCH 06/13] fix some remaining errors from merge

---
 .../loops/epoch/training_epoch_loop.py        | 14 +++++++++++---
 .../trainer/connectors/data_connector.py      |  2 +-
 pytorch_lightning/trainer/trainer.py          |  2 +-
 tests/loops/test_training_loop.py             |  7 ++++---
 tests/trainer/test_supporters.py              |  3 ---
 tests/trainer/test_trainer.py                 |  2 +-
 6 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 3ad58c7fc67c2..ff76dc8f9bcae 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,11 +501,12 @@ def _get_monitor_value(self, key: str) -> Any:
         return self.trainer.callback_metrics.get(key)
 
     def _should_check_val_epoch(self):
-        if self.trainer.enable_validation:
+        if not self.trainer.enable_validation:
             return False
 
         # first we check if `check_val_every_n_epoch is `None`, which means
-        # that we run a validation loop after n steps (based on `val_check_interval`)
+        # that we run a validation loop after n global steps (taken from the
+        # Trainer argument `val_check_interval`)
         if self.trainer.check_val_every_n_epoch is None:
             return (self.trainer.global_step + 1) % self.trainer.val_check_batch == 0
 
@@ -531,7 +532,14 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
-            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+            # if we're checking based on global step, we can start validation
+            # at any point in the training epoch
+            if self.trainer.check_val_every_n_epoch is None:
+                is_val_check_batch = True
+            else:
+                # TODO: clarify the purpose of this check.
+ is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + return is_val_check_batch def _save_loggers_on_train_batch_end(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 766a9cc1665f9..6dfc451d10d76 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -71,8 +71,8 @@ def _should_reload_val_dl(self) -> bool: def on_trainer_init( self, - reload_dataloaders_every_n_epochs: int, val_check_interval: Union[int, float], + reload_dataloaders_every_n_epochs: int, check_val_every_n_epoch: Optional[int] = None, ) -> None: self.trainer.datamodule = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 28dd456aebb4f..f1a4080836a1f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -532,8 +532,8 @@ def __init__( # init data flags self.check_val_every_n_epoch: int self._data_connector.on_trainer_init( - reload_dataloaders_every_n_epochs, val_check_interval, + reload_dataloaders_every_n_epochs, check_val_every_n_epoch, ) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 02829f7f19187..7aa206d21f20f 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -187,6 +187,7 @@ def val_dataloader(self): num_sanity_val_steps=0, ) + print("\ncalling trainer.fit") trainer.fit(model) # with a data length of 10, a val_check_interval of 15, and max_steps=30, we @@ -257,8 +258,8 @@ def val_dataloader(self): trainer.fit(model) - # with a data length of 10, a val_check_interval of 15, and max_steps=30, we - # should have validated twice + # with a data length of 10, validation every 5 epochs, and max_steps=90, we should + # validate once assert trainer.current_epoch == 9 assert trainer.global_step == 90 - assert list(model.validation_called_at_step) == [49] + assert list(model.validation_called_at_step) == [50] diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 139a54ff3bcd8..7088432e3b9ec 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -36,7 +36,6 @@ from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset -from tests.helpers.runif import RunIf def test_tensor_running_accum_reset(): @@ -382,7 +381,6 @@ def _assert_dataset(loader): apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) -@RunIf(min_gpus=2) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp): """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader @@ -436,7 +434,6 @@ def __iter__(self): assert get_len(dataloader) == float("inf") -@RunIf(min_gpus=2) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) @pytest.mark.parametrize("is_min_size_mode", [False, True]) @pytest.mark.parametrize("use_combined_loader", [False, True]) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 726289bfb8c33..7ac94538356fa 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1465,7 +1465,7 @@ def test_trainer_predict_1_gpu(tmpdir): predict(tmpdir, accelerator="gpu", devices=1) 
-@RunIf(skip_windows=True, min_gpus=2)
+@RunIf(skip_windows=True)
 def test_trainer_predict_ddp_spawn(tmpdir):
     predict(tmpdir, strategy="ddp_spawn", accelerator="auto", devices=2)

From 1be3ad98e11d56aba30b6b763d96472e8d7415a5 Mon Sep 17 00:00:00 2001
From: Nik Vaessen
Date: Fri, 8 Apr 2022 18:37:20 +0200
Subject: [PATCH 07/13] resolve comments

---
 .../loops/epoch/training_epoch_loop.py        | 25 ++++++++-----------
 .../trainer/connectors/data_connector.py      |  6 ++---
 tests/loops/test_training_loop.py             |  3 +--
 3 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index ff76dc8f9bcae..64b68c8cceda6 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,18 +501,10 @@ def _get_monitor_value(self, key: str) -> Any:
         return self.trainer.callback_metrics.get(key)
 
     def _should_check_val_epoch(self):
-        if not self.trainer.enable_validation:
-            return False
-
-        # first we check if `check_val_every_n_epoch is `None`, which means
-        # that we run a validation loop after n global steps (taken from the
-        # Trainer argument `val_check_interval`)
-        if self.trainer.check_val_every_n_epoch is None:
-            return (self.trainer.global_step + 1) % self.trainer.val_check_batch == 0
-
-        # If it's not `None`, we respect running a validation loop after every n epochs
-        else:
-            return (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        return self.trainer.enable_validation and (
+            self.trainer.check_val_every_n_epoch is None
+            or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        )
 
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """Decide if we should run validation."""
@@ -532,10 +524,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
-            # if we're checking based on global step, we can start validation
-            # at any point in the training epoch
+            # first we check if `check_val_every_n_epoch is `None`, which means
+            # that we run a validation loop after n global steps (n is taken from the
+            # Trainer argument `val_check_interval`)
             if self.trainer.check_val_every_n_epoch is None:
-                is_val_check_batch = True
+                is_val_check_batch = self.trainer.global_step % self.trainer.val_check_batch == 0
+
+            # If it's not `None`, we respect running a validation loop after every n epochs
             else:
-                # TODO: clarify the purpose of this check.
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 6dfc451d10d76..a37319cee03bd 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -73,7 +73,7 @@ def on_trainer_init( self, val_check_interval: Union[int, float], reload_dataloaders_every_n_epochs: int, - check_val_every_n_epoch: Optional[int] = None, + check_val_every_n_epoch: Optional[int], ) -> None: self.trainer.datamodule = None @@ -84,8 +84,8 @@ def on_trainer_init( if check_val_every_n_epoch is None and isinstance(val_check_interval, float): raise MisconfigurationException( - f"val_check_interval should be an integer when check_val_every_n_epoch={check_val_every_n_epoch}. " - f"Found val_check_interval={val_check_interval}" + f"`Trainer(val_check_interval={val_check_interval!r})` should be an integer " + f"when `check_val_every_n_epoch={check_val_every_n_epoch}`. " ) self.trainer.check_val_every_n_epoch = check_val_every_n_epoch diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 7aa206d21f20f..4745bc55062df 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -187,14 +187,13 @@ def val_dataloader(self): num_sanity_val_steps=0, ) - print("\ncalling trainer.fit") trainer.fit(model) # with a data length of 10, a val_check_interval of 15, and max_steps=30, we # should have validated twice assert trainer.current_epoch == 3 assert trainer.global_step == 30 - assert sorted(list(model.validation_called_at_step)) == [14, 29] + assert sorted(list(model.validation_called_at_step)) == [15, 30] def test_validation_check_interval_exceed_data_length_wrong(tmpdir): From ed832c238da28251b8b875a999b26eada034df36 Mon Sep 17 00:00:00 2001 From: Nik Vaessen Date: Fri, 8 Apr 2022 19:24:47 +0200 Subject: [PATCH 08/13] check validation works with iterable dataset without length --- tests/loops/test_training_loop.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 4745bc55062df..6433526a54f8d 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.boring_model import RandomIterableDataset def test_outputs_format(tmpdir): @@ -155,11 +156,19 @@ def training_step_end(self, outputs): trainer.fit(model) -def test_validation_check_interval_exceed_data_length_correct(tmpdir): +@pytest.mark.parametrize("use_infinite_dataset", [True, False]) +def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset): batch_size = 32 data_samples_train = 10 data_samples_val = 1 + if use_infinite_dataset: + train_ds = RandomIterableDataset(size=batch_size, count=2_400_000_000) # approx inf + else: + train_ds = RandomDataset(size=batch_size, length=data_samples_train) + + val_ds = RandomDataset(batch_size, data_samples_val) + class TestModel(BoringModel): def __init__(self): super().__init__() @@ -173,10 +182,10 @@ def validation_step(self, *args): return super().validation_step(*args) def train_dataloader(self): - return DataLoader(RandomDataset(batch_size, data_samples_train)) + 
return DataLoader(train_ds) def val_dataloader(self): - return DataLoader(RandomDataset(batch_size, data_samples_val)) + return DataLoader(val_ds) model = TestModel() trainer = Trainer( @@ -189,9 +198,13 @@ def val_dataloader(self): trainer.fit(model) - # with a data length of 10, a val_check_interval of 15, and max_steps=30, we - # should have validated twice - assert trainer.current_epoch == 3 + # with a data length of 10 (or infinite), a val_check_interval of 15, and max_steps=30, + # we should have validated twice + if use_infinite_dataset: + assert trainer.current_epoch == 1 + else: + assert trainer.current_epoch == 3 + assert trainer.global_step == 30 assert sorted(list(model.validation_called_at_step)) == [15, 30] From 635cf70295a5b571736e4f7bf810ea5d05da9098 Mon Sep 17 00:00:00 2001 From: Nik Date: Tue, 12 Apr 2022 17:43:16 +0200 Subject: [PATCH 09/13] Update test_training_loop.py use more sensible value to approx inf --- tests/loops/test_training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 6433526a54f8d..7bfb0885fb792 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -163,7 +163,7 @@ def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infini data_samples_val = 1 if use_infinite_dataset: - train_ds = RandomIterableDataset(size=batch_size, count=2_400_000_000) # approx inf + train_ds = RandomIterableDataset(size=batch_size, count=10_000) # approx inf else: train_ds = RandomDataset(size=batch_size, length=data_samples_train) From b25c61b716b3d8da5ef123a63d8e93bde46f9c74 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Apr 2022 16:50:04 +0530 Subject: [PATCH 10/13] self review and improve tests and warnings --- CHANGELOG.md | 2 +- .../loops/epoch/training_epoch_loop.py | 7 +- .../trainer/connectors/data_connector.py | 6 +- pytorch_lightning/trainer/trainer.py | 12 +- tests/loops/test_training_loop.py | 126 +----------------- .../flags/test_check_val_every_n_epoch.py | 39 +++++- .../trainer/flags/test_val_check_interval.py | 71 +++++++++- 7 files changed, 119 insertions(+), 144 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfe4a311f9201..091583f16c27a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added -- Support setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#8135](https://github.com/PyTorchLightning/pytorch-lightning/issues/8135)) +- Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993)) - diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 64b68c8cceda6..e82120dab0a03 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -524,13 +524,10 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float("inf"): - # first we check if `check_val_every_n_epoch is `None`, which means - # that we run a validation loop after n global steps (n is taken from the - # Trainer argument `val_check_interval`) + # if `check_val_every_n_epoch is `None`, run a validation loop after n global steps if self.trainer.check_val_every_n_epoch is None: is_val_check_batch = self.trainer.global_step % self.trainer.val_check_batch == 0 - - # If it's not `None`, we respect running a validation loop after every n epochs + # else condition it based on the batch_idx of the current epoch else: # TODO: clarify the purpose of this check. is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index a37319cee03bd..978271f2a5d11 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -79,13 +79,13 @@ def on_trainer_init( if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int): raise MisconfigurationException( - f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}" + f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}." ) if check_val_every_n_epoch is None and isinstance(val_check_interval, float): raise MisconfigurationException( - f"`Trainer(val_check_interval={val_check_interval!r})` should be an integer " - f"when `check_val_every_n_epoch={check_val_every_n_epoch}`. " + "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`," + f" found {val_check_interval!r}." ) self.trainer.check_val_every_n_epoch = check_val_every_n_epoch diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f1a4080836a1f..9d62b77719464 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -247,8 +247,9 @@ def __init__( :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``. - check_val_every_n_epoch: Perform a validation loop every after every `n` train epochs. If `None`, validation - will be done solely based on the number of steps, requiring `val_check_interval` to be an integer value. + check_val_every_n_epoch: Perform a validation loop every after every `N` training epochs. If ``None``, + validation will be done solely based on the number of training steps, requiring ``val_check_interval`` + to be an integer value. 
Default: ``1``. default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. @@ -409,7 +410,7 @@ def __init__( val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the amount of batches in the training set when - `check_val_every_n_epoch=None`, otherwise the validation set is never checked. + ``check_val_every_n_epoch=None``. Default: ``1.0``. enable_model_summary: Whether to enable model summarization by default. @@ -1889,9 +1890,8 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - raise ValueError( f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " f"to the number of the training batches ({self.num_training_batches}). " - "If you want to disable validation set `limit_val_batches` to 0.0 instead. " - "If you want to validate based on the step count instead of the epoch count, " - "set `check_val_every_n_epoch=None`." + "If you want to disable validation set `limit_val_batches` to 0.0 instead." + "If you want to validate based on the total step count, set `check_val_every_n_epoch=None`." ) else: if not has_len_all_ranks(self.train_dataloader, self.strategy, module): diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 7bfb0885fb792..3de02d5f8bb1c 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -13,12 +13,9 @@ # limitations under the License. import pytest import torch -from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset -from tests.helpers.boring_model import RandomIterableDataset +from tests.helpers import BoringModel def test_outputs_format(tmpdir): @@ -154,124 +151,3 @@ def training_step_end(self, outputs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) trainer.fit(model) - - -@pytest.mark.parametrize("use_infinite_dataset", [True, False]) -def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset): - batch_size = 32 - data_samples_train = 10 - data_samples_val = 1 - - if use_infinite_dataset: - train_ds = RandomIterableDataset(size=batch_size, count=10_000) # approx inf - else: - train_ds = RandomDataset(size=batch_size, length=data_samples_train) - - val_ds = RandomDataset(batch_size, data_samples_val) - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.validation_called_at_step = set() - - def training_step(self, batch, batch_idx): - return super().training_step(batch, batch_idx) - - def validation_step(self, *args): - self.validation_called_at_step.add(int(self.trainer.global_step)) - return super().validation_step(*args) - - def train_dataloader(self): - return DataLoader(train_ds) - - def val_dataloader(self): - return DataLoader(val_ds) - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=data_samples_train * 3, - val_check_interval=15, - check_val_every_n_epoch=None, - num_sanity_val_steps=0, - ) - - trainer.fit(model) - - # with a data length of 10 (or infinite), a val_check_interval of 15, and max_steps=30, - # we should have validated twice - if use_infinite_dataset: - assert trainer.current_epoch == 1 - else: - assert 
trainer.current_epoch == 3 - - assert trainer.global_step == 30 - assert sorted(list(model.validation_called_at_step)) == [15, 30] - - -def test_validation_check_interval_exceed_data_length_wrong(tmpdir): - model = BoringModel() - - with pytest.raises(ValueError): - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=200, - val_check_interval=100, - check_val_every_n_epoch=1, - num_sanity_val_steps=0, - ) - trainer.fit(model) - - -def test_validation_check_interval_float_wrong(tmpdir): - model = BoringModel() - - with pytest.raises(MisconfigurationException): - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=200, - val_check_interval=0.5, - check_val_every_n_epoch=None, - num_sanity_val_steps=0, - ) - trainer.fit(model) - - -def test_validation_loop_every_5_epochs(tmpdir): - batch_size = 32 - data_samples_train = 10 - data_samples_val = 1 - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.validation_called_at_step = set() - - def training_step(self, batch, batch_idx): - return super().training_step(batch, batch_idx) - - def validation_step(self, *args): - self.validation_called_at_step.add(int(self.trainer.global_step)) - return super().validation_step(*args) - - def train_dataloader(self): - return DataLoader(RandomDataset(batch_size, data_samples_train)) - - def val_dataloader(self): - return DataLoader(RandomDataset(batch_size, data_samples_val)) - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=data_samples_train * 9, - check_val_every_n_epoch=5, - num_sanity_val_steps=0, - ) - - trainer.fit(model) - - # with a data length of 10, validation every 5 epochs, and max_steps=90, we should - # validate once - assert trainer.current_epoch == 9 - assert trainer.global_step == 90 - assert list(model.validation_called_at_step) == [50] diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index 97c6ddf7803ab..f72ae674cce18 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest +from torch.utils.data import DataLoader -from pytorch_lightning.trainer import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.trainer.trainer import Trainer +from tests.helpers import BoringModel, RandomDataset @pytest.mark.parametrize( @@ -46,3 +47,37 @@ def on_validation_epoch_start(self) -> None: assert model.val_epoch_calls == expected_val_loop_calls assert model.val_batches == expected_val_batches + + +def test_check_val_every_n_epoch_with_max_steps(tmpdir): + data_samples_train = 2 + check_val_every_n_epoch = 3 + max_epochs = 4 + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.validation_called_at_step = set() + + def validation_step(self, *args): + self.validation_called_at_step.add(int(self.trainer.global_step)) + return super().validation_step(*args) + + def train_dataloader(self): + return DataLoader(RandomDataset(32, data_samples_train)) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=data_samples_train * max_epochs, + check_val_every_n_epoch=check_val_every_n_epoch, + num_sanity_val_steps=0, + ) + + trainer.fit(model) + + # with a data length of 10, validation every 5 epochs, and max_steps=90, we should + # validate once + assert trainer.current_epoch == max_epochs + assert trainer.global_step == max_epochs * data_samples_train + assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch] diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 685e104805daa..335957b510c88 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -14,9 +14,12 @@ import logging import pytest +from torch.utils.data import DataLoader -from pytorch_lightning.trainer import Trainer -from tests.helpers import BoringModel +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.boring_model import RandomIterableDataset @pytest.mark.parametrize("max_epochs", [1, 2, 3]) @@ -57,3 +60,67 @@ def test_val_check_interval_info_message(caplog, value): with caplog.at_level(logging.INFO): Trainer() assert message not in caplog.text + + +@pytest.mark.parametrize("use_infinite_dataset", [True, False]) +def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset): + data_samples_train = 4 + max_epochs = 3 + max_steps = data_samples_train * max_epochs + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.validation_called_at_step = set() + + def validation_step(self, *args): + self.validation_called_at_step.add(int(self.trainer.global_step)) + return super().validation_step(*args) + + def train_dataloader(self): + if use_infinite_dataset: + train_ds = RandomIterableDataset(32, count=max_steps + 100) # approx inf + else: + train_ds = RandomDataset(32, length=data_samples_train) + return DataLoader(train_ds) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_val_batches=1, + max_steps=max_steps, + val_check_interval=3, + check_val_every_n_epoch=None, + num_sanity_val_steps=0, + ) + + trainer.fit(model) + + assert trainer.current_epoch == 1 if use_infinite_dataset else max_epochs + assert trainer.global_step == max_steps + + # with a data length of 10 (or infinite), a val_check_interval of 15, and max_steps=30, + # we should 
have validated twice + assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12] + + +def test_validation_check_interval_exceed_data_length_wrong(): + trainer = Trainer( + limit_train_batches=10, + val_check_interval=100, + ) + + with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"): + trainer.fit(BoringModel()) + + +def test_val_check_interval_float_with_none_check_val_every_n_epoch(): + """Test that an exception is raised with `val_check_interval` is set to float with + `check_val_every_n_epoch=None`""" + with pytest.raises( + MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`" + ): + Trainer( + val_check_interval=0.5, + check_val_every_n_epoch=None, + ) From c7e2f0326ead0356caf8ea9d7c06fa6ecde94b3c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Apr 2022 17:12:33 +0530 Subject: [PATCH 11/13] condition on stepping batches --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 11 +++++------ pytorch_lightning/trainer/trainer.py | 6 +++--- tests/trainer/flags/test_val_check_interval.py | 9 +++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index e82120dab0a03..ec02a099f6cbb 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -524,13 +524,12 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float("inf"): - # if `check_val_every_n_epoch is `None`, run a validation loop after n global steps - if self.trainer.check_val_every_n_epoch is None: - is_val_check_batch = self.trainer.global_step % self.trainer.val_check_batch == 0 + # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches # else condition it based on the batch_idx of the current epoch - else: - # TODO: clarify the purpose of this check. - is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + current_iteration = ( + self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx + ) + is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0 return is_val_check_batch diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2b61d5c0401d3..189281627e5b9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -243,7 +243,7 @@ def __init__( Default: ``True``. check_val_every_n_epoch: Perform a validation loop every after every `N` training epochs. If ``None``, - validation will be done solely based on the number of training steps, requiring ``val_check_interval`` + validation will be done solely based on the number of training batches, requiring ``val_check_interval`` to be an integer value. Default: ``1``. @@ -404,7 +404,7 @@ def __init__( val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training - batches. An ``int`` value can only be higher than the amount of batches in the training set when + batches. 
An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``. Default: ``1.0``. @@ -1837,7 +1837,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " f"to the number of the training batches ({self.num_training_batches}). " "If you want to disable validation set `limit_val_batches` to 0.0 instead." - "If you want to validate based on the total step count, set `check_val_every_n_epoch=None`." + "If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`." ) else: if not has_len_all_ranks(self.train_dataloader, self.strategy, module): diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 335957b510c88..877aa180ecf4d 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -78,10 +78,11 @@ def validation_step(self, *args): return super().validation_step(*args) def train_dataloader(self): - if use_infinite_dataset: - train_ds = RandomIterableDataset(32, count=max_steps + 100) # approx inf - else: - train_ds = RandomDataset(32, length=data_samples_train) + train_ds = ( + RandomIterableDataset(32, count=max_steps + 100) + if use_infinite_dataset + else RandomDataset(32, length=data_samples_train) + ) return DataLoader(train_ds) model = TestModel() From 28e40787d137d818e888e1418341cd4d3e41b3df Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Apr 2022 22:38:37 +0530 Subject: [PATCH 12/13] update tests --- tests/trainer/flags/test_check_val_every_n_epoch.py | 2 +- tests/trainer/flags/test_val_check_interval.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index f72ae674cce18..e8475842316ad 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -60,7 +60,7 @@ def __init__(self): self.validation_called_at_step = set() def validation_step(self, *args): - self.validation_called_at_step.add(int(self.trainer.global_step)) + self.validation_called_at_step.add(self.global_step) return super().validation_step(*args) def train_dataloader(self): diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 877aa180ecf4d..b4818a4a963a5 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -74,7 +74,7 @@ def __init__(self): self.validation_called_at_step = set() def validation_step(self, *args): - self.validation_called_at_step.add(int(self.trainer.global_step)) + self.validation_called_at_step.add(self.global_step) return super().validation_step(*args) def train_dataloader(self): @@ -111,8 +111,9 @@ def test_validation_check_interval_exceed_data_length_wrong(): val_check_interval=100, ) + model = BoringModel() with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"): - trainer.fit(BoringModel()) + trainer.fit(model) def test_val_check_interval_float_with_none_check_val_every_n_epoch(): From 6e4a5a6374ce90aa91dbab4ffa55113b3a5eca8b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 21 Apr 2022 14:26:51 +0530 Subject: [PATCH 13/13] cleanup --- tests/trainer/flags/test_check_val_every_n_epoch.py | 2 -- 
 tests/trainer/flags/test_val_check_interval.py       | 5 +----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py
index e8475842316ad..ca2537b829cd7 100644
--- a/tests/trainer/flags/test_check_val_every_n_epoch.py
+++ b/tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -76,8 +76,6 @@ def train_dataloader(self):
 
     trainer.fit(model)
 
-    # with a data length of 10, validation every 5 epochs, and max_steps=90, we should
-    # validate once
     assert trainer.current_epoch == max_epochs
     assert trainer.global_step == max_epochs * data_samples_train
     assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch]
diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py
index b4818a4a963a5..b575faa81203c 100644
--- a/tests/trainer/flags/test_val_check_interval.py
+++ b/tests/trainer/flags/test_val_check_interval.py
@@ -99,9 +99,6 @@ def train_dataloader(self):
 
     assert trainer.current_epoch == 1 if use_infinite_dataset else max_epochs
     assert trainer.global_step == max_steps
-
-    # with a data length of 10 (or infinite), a val_check_interval of 15, and max_steps=30,
-    # we should have validated twice
     assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12]
 
 
@@ -114,7 +114,7 @@ def test_validation_check_interval_exceed_data_length_wrong():
 
 
 def test_val_check_interval_float_with_none_check_val_every_n_epoch():
-    """Test that an exception is raised with `val_check_interval` is set to float with
+    """Test that an exception is raised when `val_check_interval` is set to float with
     `check_val_every_n_epoch=None`"""
     with pytest.raises(
         MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"