@@ -32,50 +32,38 @@ def test_initialize_state(tmpdir):
3232)
3333def test_trainer_state_while_running (tmpdir , extra_params ):
3434 trainer = Trainer (default_root_dir = tmpdir , ** extra_params , auto_lr_find = True )
35- fdr = trainer .fast_dev_run
3635
3736 class TestModel (BoringModel ):
3837 def __init__ (self , expected_state ):
3938 super ().__init__ ()
4039 self .expected_state = expected_state
41- self .called = set ()
4240 self .lr = 0.1
4341
4442 def on_batch_start (self , * _ ):
4543 assert self .trainer .state == self .expected_state
4644
4745 def on_train_batch_start (self , * _ ):
48- self .called .add ("train" )
4946 assert self .trainer .training
5047
5148 def on_sanity_check_start (self , * _ ):
52- self .called .add ("sanity" )
5349 assert self .trainer .sanity_checking
5450
5551 def on_validation_batch_start (self , * _ ):
56- self .called .add ("validation" )
5752 assert self .trainer .validating or self .trainer .sanity_checking
5853
5954 def on_test_batch_start (self , * _ ):
60- self .called .add ("test" )
6155 assert self .trainer .testing
6256
6357 model = TestModel (TrainerState .TUNING )
6458 trainer .tune (model )
65- if fdr :
66- assert not model .called
67- else :
68- assert model .called == {'train' , 'validation' }
6959 assert trainer .state == TrainerState .FINISHED
7060
7161 model = TestModel (TrainerState .FITTING )
7262 trainer .fit (model )
73- assert model .called == {'train' , 'validation' } if fdr else {'train' , 'sanity' , 'validation' }
7463 assert trainer .state == TrainerState .FINISHED
7564
7665 model = TestModel (TrainerState .TESTING )
7766 trainer .test (model )
78- assert model .called == {'test' }
7967 assert trainer .state == TrainerState .FINISHED
8068
8169
0 commit comments