@@ -36,30 +36,59 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
3636 class FastDevRunModel (BoringModel ):
3737 def __init__ (self ):
3838 super ().__init__ ()
39- self .training_step_called = False
40- self .validation_step_called = False
41- self .test_step_called = False
39+ self .training_step_call_count = 0
40+ self .training_epoch_end_call_count = 0
41+ self .validation_step_call_count = 0
42+ self .validation_epoch_end_call_count = 0
43+ self .test_step_call_count = 0
4244
4345 def training_step (self , batch , batch_idx ):
4446 self .log ('some_metric' , torch .tensor (7. ))
4547 self .logger .experiment .dummy_log ('some_distribution' , torch .randn (7 ) + batch_idx )
46- self .training_step_called = True
48+ self .training_step_call_count += 1
4749 return super ().training_step (batch , batch_idx )
4850
51+ def training_epoch_end (self , outputs ):
52+ self .training_epoch_end_call_count += 1
53+ super ().training_epoch_end (outputs )
54+
4955 def validation_step (self , batch , batch_idx ):
50- self .validation_step_called = True
56+ self .validation_step_call_count += 1
5157 return super ().validation_step (batch , batch_idx )
5258
59+ def validation_epoch_end (self , outputs ):
60+ self .validation_epoch_end_call_count += 1
61+ super ().validation_epoch_end (outputs )
62+
63+ def test_step (self , batch , batch_idx ):
64+ self .test_step_call_count += 1
65+ return super ().test_step (batch , batch_idx )
66+
5367 checkpoint_callback = ModelCheckpoint ()
5468 early_stopping_callback = EarlyStopping ()
5569 trainer_config = dict (
5670 fast_dev_run = fast_dev_run ,
71+ val_check_interval = 2 ,
5772 logger = True ,
5873 log_every_n_steps = 1 ,
5974 callbacks = [checkpoint_callback , early_stopping_callback ],
6075 )
6176
62- def _make_fast_dev_run_assertions (trainer ):
77+ def _make_fast_dev_run_assertions (trainer , model ):
78+ # check the call count for train/val/test step/epoch
79+ assert model .training_step_call_count == fast_dev_run
80+ assert model .training_epoch_end_call_count == 1
81+ assert model .validation_step_call_count == 0 if model .validation_step is None else fast_dev_run
82+ assert model .validation_epoch_end_call_count == 0 if model .validation_step is None else 1
83+ assert model .test_step_call_count == fast_dev_run
84+
85+ # check trainer arguments
86+ assert trainer .max_steps == fast_dev_run
87+ assert trainer .num_sanity_val_steps == 0
88+ assert trainer .max_epochs == 1
89+ assert trainer .val_check_interval == 1.0
90+ assert trainer .check_val_every_n_epoch == 1
91+
6392 # there should be no logger with fast_dev_run
6493 assert isinstance (trainer .logger , DummyLogger )
6594 assert len (trainer .dev_debugger .logged_metrics ) == fast_dev_run
@@ -76,13 +105,10 @@ def _make_fast_dev_run_assertions(trainer):
76105 train_val_step_model = FastDevRunModel ()
77106 trainer = Trainer (** trainer_config )
78107 results = trainer .fit (train_val_step_model )
79- assert results
108+ trainer . test ( ckpt_path = None )
80109
81- # make sure both training_step and validation_step were called
82- assert train_val_step_model .training_step_called
83- assert train_val_step_model .validation_step_called
84-
85- _make_fast_dev_run_assertions (trainer )
110+ assert results
111+ _make_fast_dev_run_assertions (trainer , train_val_step_model )
86112
87113 # -----------------------
88114 # also called once with no val step
@@ -92,10 +118,7 @@ def _make_fast_dev_run_assertions(trainer):
92118
93119 trainer = Trainer (** trainer_config )
94120 results = trainer .fit (train_step_only_model )
95- assert results
121+ trainer . test ( ckpt_path = None )
96122
97- # make sure only training_step was called
98- assert train_step_only_model .training_step_called
99- assert not train_step_only_model .validation_step_called
100-
101- _make_fast_dev_run_assertions (trainer )
123+ assert results
124+ _make_fast_dev_run_assertions (trainer , train_step_only_model )
0 commit comments