Skip to content

Commit c80e45d

Browse files
authored
Fix val_check_interval with fast_dev_run (#5540)
* fix val_check_interval with fast_dev_run * chlog
1 parent 6926b84 commit c80e45d

File tree

3 files changed

+45
-19
lines changed

3 files changed

+45
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727
- Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))
2828

2929

30+
- Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540))
31+
32+
3033
## [1.1.4] - 2021-01-12
3134

3235
### Added

pytorch_lightning/trainer/connectors/debugging_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def on_init_start(
5959
self.trainer.max_steps = fast_dev_run
6060
self.trainer.num_sanity_val_steps = 0
6161
self.trainer.max_epochs = 1
62-
self.trainer.val_check_interval = 1.0
62+
val_check_interval = 1.0
6363
self.trainer.check_val_every_n_epoch = 1
6464
self.trainer.logger = DummyLogger()
6565

tests/trainer/flags/test_fast_dev_run.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,59 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
3636
class FastDevRunModel(BoringModel):
3737
def __init__(self):
3838
super().__init__()
39-
self.training_step_called = False
40-
self.validation_step_called = False
41-
self.test_step_called = False
39+
self.training_step_call_count = 0
40+
self.training_epoch_end_call_count = 0
41+
self.validation_step_call_count = 0
42+
self.validation_epoch_end_call_count = 0
43+
self.test_step_call_count = 0
4244

4345
def training_step(self, batch, batch_idx):
4446
self.log('some_metric', torch.tensor(7.))
4547
self.logger.experiment.dummy_log('some_distribution', torch.randn(7) + batch_idx)
46-
self.training_step_called = True
48+
self.training_step_call_count += 1
4749
return super().training_step(batch, batch_idx)
4850

51+
def training_epoch_end(self, outputs):
52+
self.training_epoch_end_call_count += 1
53+
super().training_epoch_end(outputs)
54+
4955
def validation_step(self, batch, batch_idx):
50-
self.validation_step_called = True
56+
self.validation_step_call_count += 1
5157
return super().validation_step(batch, batch_idx)
5258

59+
def validation_epoch_end(self, outputs):
60+
self.validation_epoch_end_call_count += 1
61+
super().validation_epoch_end(outputs)
62+
63+
def test_step(self, batch, batch_idx):
64+
self.test_step_call_count += 1
65+
return super().test_step(batch, batch_idx)
66+
5367
checkpoint_callback = ModelCheckpoint()
5468
early_stopping_callback = EarlyStopping()
5569
trainer_config = dict(
5670
fast_dev_run=fast_dev_run,
71+
val_check_interval=2,
5772
logger=True,
5873
log_every_n_steps=1,
5974
callbacks=[checkpoint_callback, early_stopping_callback],
6075
)
6176

62-
def _make_fast_dev_run_assertions(trainer):
77+
def _make_fast_dev_run_assertions(trainer, model):
78+
# check the call count for train/val/test step/epoch
79+
assert model.training_step_call_count == fast_dev_run
80+
assert model.training_epoch_end_call_count == 1
81+
assert model.validation_step_call_count == 0 if model.validation_step is None else fast_dev_run
82+
assert model.validation_epoch_end_call_count == 0 if model.validation_step is None else 1
83+
assert model.test_step_call_count == fast_dev_run
84+
85+
# check trainer arguments
86+
assert trainer.max_steps == fast_dev_run
87+
assert trainer.num_sanity_val_steps == 0
88+
assert trainer.max_epochs == 1
89+
assert trainer.val_check_interval == 1.0
90+
assert trainer.check_val_every_n_epoch == 1
91+
6392
# there should be no logger with fast_dev_run
6493
assert isinstance(trainer.logger, DummyLogger)
6594
assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run
@@ -76,13 +105,10 @@ def _make_fast_dev_run_assertions(trainer):
76105
train_val_step_model = FastDevRunModel()
77106
trainer = Trainer(**trainer_config)
78107
results = trainer.fit(train_val_step_model)
79-
assert results
108+
trainer.test(ckpt_path=None)
80109

81-
# make sure both training_step and validation_step were called
82-
assert train_val_step_model.training_step_called
83-
assert train_val_step_model.validation_step_called
84-
85-
_make_fast_dev_run_assertions(trainer)
110+
assert results
111+
_make_fast_dev_run_assertions(trainer, train_val_step_model)
86112

87113
# -----------------------
88114
# also called once with no val step
@@ -92,10 +118,7 @@ def _make_fast_dev_run_assertions(trainer):
92118

93119
trainer = Trainer(**trainer_config)
94120
results = trainer.fit(train_step_only_model)
95-
assert results
121+
trainer.test(ckpt_path=None)
96122

97-
# make sure only training_step was called
98-
assert train_step_only_model.training_step_called
99-
assert not train_step_only_model.validation_step_called
100-
101-
_make_fast_dev_run_assertions(trainer)
123+
assert results
124+
_make_fast_dev_run_assertions(trainer, train_step_only_model)

0 commit comments

Comments
 (0)