Commit a758d90

Support val_check_interval values higher than number of training batches (#11993)

Authored by nikvaessen and rohitgr7
Co-authored-by: rohitgr7 <[email protected]>
1 parent f300b60 commit a758d90

6 files changed: +134 −19 lines

CHANGELOG.md

Lines changed: 1 addition & 2 deletions

@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
-
+- Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))

 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
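For context, a minimal usage sketch of the behaviour this entry describes (the two flags are the ones touched by this PR; the surrounding setup is a hypothetical illustration, not code from the commit):

    from pytorch_lightning import Trainer

    # Validate every 500 training batches, counted across epoch boundaries,
    # even when a single epoch contains fewer than 500 batches.
    trainer = Trainer(
        val_check_interval=500,        # must be an integer in this mode
        check_val_every_n_epoch=None,  # disable the per-epoch validation schedule
    )
    # trainer.fit(model)  # `model` would be any LightningModule; omitted in this sketch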

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 10 additions & 4 deletions

@@ -501,9 +501,9 @@ def _get_monitor_value(self, key: str) -> Any:
         return self.trainer.callback_metrics.get(key)

     def _should_check_val_epoch(self):
-        return (
-            self.trainer.enable_validation
-            and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        return self.trainer.enable_validation and (
+            self.trainer.check_val_every_n_epoch is None
+            or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         )

     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -524,7 +524,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
-            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+            # if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
+            # else condition it based on the batch_idx of the current epoch
+            current_iteration = (
+                self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
+            )
+            is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0
+
         return is_val_check_batch

     def _save_loggers_on_train_batch_end(self) -> None:
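A standalone sketch of the decision the new code implements, with the trainer state pulled out into plain arguments (the helper below is illustrative only and does not exist in the library):

    def should_run_validation(
        batch_idx: int,
        batches_that_stepped: int,
        val_check_batch: int,
        check_val_every_n_epoch,  # int or None
    ) -> bool:
        # When `check_val_every_n_epoch` is None, count batches across epoch
        # boundaries, so `val_check_interval` may exceed the epoch length;
        # otherwise keep the original per-epoch behaviour based on `batch_idx`.
        current_iteration = batches_that_stepped if check_val_every_n_epoch is None else batch_idx
        return (current_iteration + 1) % val_check_batch == 0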

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 10 additions & 3 deletions

@@ -71,14 +71,21 @@ def _should_reload_val_dl(self) -> bool:

     def on_trainer_init(
         self,
-        check_val_every_n_epoch: int,
+        val_check_interval: Union[int, float],
         reload_dataloaders_every_n_epochs: int,
+        check_val_every_n_epoch: Optional[int],
     ) -> None:
         self.trainer.datamodule = None

-        if not isinstance(check_val_every_n_epoch, int):
+        if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
             raise MisconfigurationException(
-                f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
+                f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}."
+            )
+
+        if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
+            raise MisconfigurationException(
+                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
+                f" found {val_check_interval!r}."
             )

         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
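The two guards above translate into the following user-facing rules (hypothetical calls shown for illustration; both raise MisconfigurationException with the messages added in this diff):

    from pytorch_lightning import Trainer

    # Invalid: a fractional interval has no meaning without an epoch-based schedule.
    # Trainer(val_check_interval=0.5, check_val_every_n_epoch=None)

    # Invalid: `check_val_every_n_epoch` must be an int or None.
    # Trainer(check_val_every_n_epoch="2")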

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 6 deletions

@@ -145,7 +145,7 @@ def __init__(
         enable_progress_bar: bool = True,
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
-        check_val_every_n_epoch: int = 1,
+        check_val_every_n_epoch: Optional[int] = 1,
         fast_dev_run: Union[int, bool] = False,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
@@ -242,10 +242,11 @@ def __init__(
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
                Default: ``True``.

-            check_val_every_n_epoch: Check val every n train epochs.
+            check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
+                validation will be done solely based on the number of training batches, requiring
+                ``val_check_interval`` to be an integer value.
                Default: ``1``.

-
            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -403,7 +404,8 @@ def __init__(

            val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
                after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
-                batches.
+                batches. An ``int`` value can only be higher than the number of training batches when
+                ``check_val_every_n_epoch=None``.
                Default: ``1.0``.

            enable_model_summary: Whether to enable model summarization by default.
@@ -524,8 +526,9 @@ def __init__(
        # init data flags
        self.check_val_every_n_epoch: int
        self._data_connector.on_trainer_init(
-            check_val_every_n_epoch,
+            val_check_interval,
            reload_dataloaders_every_n_epochs,
+            check_val_every_n_epoch,
        )

        # gradient clipping
@@ -1829,11 +1832,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
+            if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of the training batches ({self.num_training_batches}). "
                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                    "If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                )
        else:
            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
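Putting the pieces together, a sketch of how the relaxed check behaves at fit time (illustrative values; only the flag semantics come from this diff):

    from pytorch_lightning import Trainer

    # Still an error: with the default check_val_every_n_epoch=1, an integer
    # val_check_interval larger than the epoch raises ValueError during fit().
    # Trainer(limit_train_batches=10, val_check_interval=100)

    # Now allowed: the same interval is legal once the per-epoch schedule is disabled,
    # and validation runs every 100 stepped batches, counted across epochs.
    trainer = Trainer(limit_train_batches=10, val_check_interval=100, check_val_every_n_epoch=None)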

tests/trainer/flags/test_check_val_every_n_epoch.py

Lines changed: 35 additions & 2 deletions

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+from torch.utils.data import DataLoader

-from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.trainer.trainer import Trainer
+from tests.helpers import BoringModel, RandomDataset


 @pytest.mark.parametrize(
@@ -46,3 +47,35 @@ def on_validation_epoch_start(self) -> None:

     assert model.val_epoch_calls == expected_val_loop_calls
     assert model.val_batches == expected_val_batches
+
+
+def test_check_val_every_n_epoch_with_max_steps(tmpdir):
+    data_samples_train = 2
+    check_val_every_n_epoch = 3
+    max_epochs = 4
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at_step = set()
+
+        def validation_step(self, *args):
+            self.validation_called_at_step.add(self.global_step)
+            return super().validation_step(*args)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, data_samples_train))
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=data_samples_train * max_epochs,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        num_sanity_val_steps=0,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.current_epoch == max_epochs
+    assert trainer.global_step == max_epochs * data_samples_train
+    assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch]

tests/trainer/flags/test_val_check_interval.py

Lines changed: 68 additions & 2 deletions

@@ -14,9 +14,12 @@
 import logging

 import pytest
+from torch.utils.data import DataLoader

-from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import RandomIterableDataset


 @pytest.mark.parametrize("max_epochs", [1, 2, 3])
@@ -57,3 +60,66 @@ def test_val_check_interval_info_message(caplog, value):
     with caplog.at_level(logging.INFO):
         Trainer()
     assert message not in caplog.text
+
+
+@pytest.mark.parametrize("use_infinite_dataset", [True, False])
+def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
+    data_samples_train = 4
+    max_epochs = 3
+    max_steps = data_samples_train * max_epochs
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at_step = set()
+
+        def validation_step(self, *args):
+            self.validation_called_at_step.add(self.global_step)
+            return super().validation_step(*args)
+
+        def train_dataloader(self):
+            train_ds = (
+                RandomIterableDataset(32, count=max_steps + 100)
+                if use_infinite_dataset
+                else RandomDataset(32, length=data_samples_train)
+            )
+            return DataLoader(train_ds)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=1,
+        max_steps=max_steps,
+        val_check_interval=3,
+        check_val_every_n_epoch=None,
+        num_sanity_val_steps=0,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.current_epoch == (1 if use_infinite_dataset else max_epochs)
+    assert trainer.global_step == max_steps
+    assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12]
+
+
+def test_validation_check_interval_exceed_data_length_wrong():
+    trainer = Trainer(
+        limit_train_batches=10,
+        val_check_interval=100,
+    )
+
+    model = BoringModel()
+    with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"):
+        trainer.fit(model)
+
+
+def test_val_check_interval_float_with_none_check_val_every_n_epoch():
+    """Test that an exception is raised when `val_check_interval` is set to a float with
+    `check_val_every_n_epoch=None`."""
+    with pytest.raises(
+        MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
+    ):
+        Trainer(
+            val_check_interval=0.5,
+            check_val_every_n_epoch=None,
+        )
