Skip to content

Commit 74d79e7

Browse files
authored
Raise an exception if check_val_every_n_epoch is not an integer (#6411)
* raise an exception if check_val_every_n_epoch is not an integer * remove unused object * add type hints * add return type * update exception message * update exception message
1 parent 615b2f7 commit 74d79e7

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@ class DataConnector(object):
2626
def __init__(self, trainer):
2727
self.trainer = trainer
2828

29-
def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node):
29+
def on_trainer_init(
30+
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
31+
) -> None:
3032
self.trainer.datamodule = None
3133
self.trainer.prepare_data_per_node = prepare_data_per_node
3234

35+
if not isinstance(check_val_every_n_epoch, int):
36+
raise MisconfigurationException(
37+
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
38+
)
39+
3340
self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
3441
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
3542
self.trainer._is_data_prepared = False

tests/trainer/test_trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,3 +1828,13 @@ def compare_optimizers():
18281828
trainer.max_epochs = 2 # simulate multiple fit calls
18291829
trainer.fit(model)
18301830
compare_optimizers()
1831+
1832+
1833+
def test_check_val_every_n_epoch_exception(tmpdir):
1834+
1835+
with pytest.raises(MisconfigurationException, match="should be an integer."):
1836+
Trainer(
1837+
default_root_dir=tmpdir,
1838+
max_epochs=1,
1839+
check_val_every_n_epoch=1.2,
1840+
)

0 commit comments

Comments
 (0)