@@ -115,20 +115,34 @@ def test_all_features_cpu_model(tmpdir):
115115
116116
117117def test_early_stopping_cpu_model (tmpdir ):
118- """Test each of the trainer options."""
118+ """Test each of the trainer options. Simply test the combo trainer and
119+ model; callbacks functionality tests are in /tests/callbacks"""
120+ class ModelTrainVal (BoringModel ):
121+ def __init__ (self , * args , ** kwargs ):
122+ super ().__init__ (* args , ** kwargs )
123+
124+ def validation_step (self , batch , batch_idx ):
125+ output = self .layer (batch )
126+ loss = self .loss (batch , output )
127+ return {"x" : loss }
128+
129+ def validation_epoch_end (self , outputs ) -> None :
130+ val_loss = torch .stack ([x ["x" ] for x in outputs ]).mean ()
131+ self .log ('val_loss' , val_loss )
132+
119133 stopping = EarlyStopping (monitor = "val_loss" , min_delta = 0.1 )
120134 trainer_options = dict (
121135 default_root_dir = tmpdir ,
122136 callbacks = [stopping ],
123137 max_epochs = 2 ,
124- gradient_clip_val = 1.0 ,
125- overfit_batches = 0.20 ,
138+ gradient_clip_val = 1 ,
126139 track_grad_norm = 2 ,
127- limit_train_batches = 0.1 ,
140+ limit_train_batches = 0.2 ,
128141 limit_val_batches = 0.1 ,
129142 )
130143
131- model = BoringModel ()
144+ model = ModelTrainVal ()
145+
132146 tpipes .run_model_test (trainer_options , model , on_gpu = False )
133147
134148 # test freeze on cpu
@@ -199,7 +213,29 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
199213
200214def test_running_test_after_fitting (tmpdir ):
201215 """Verify test() on fitted model."""
202- model = BoringModel ()
216+ class ModelTrainValTest (BoringModel ):
217+ def __init__ (self , * args , ** kwargs ):
218+ super ().__init__ (* args , ** kwargs )
219+
220+ def validation_step (self , batch , batch_idx ):
221+ output = self .layer (batch )
222+ loss = self .loss (batch , output )
223+ return {"x" : loss }
224+
225+ def validation_epoch_end (self , outputs ) -> None :
226+ val_loss = torch .stack ([x ["x" ] for x in outputs ]).mean ()
227+ self .log ('val_loss' , val_loss )
228+
229+ def test_step (self , batch , batch_idx ):
230+ output = self .layer (batch )
231+ loss = self .loss (batch , output )
232+ return {"y" : loss }
233+
234+ def test_epoch_end (self , outputs ) -> None :
235+ test_loss = torch .stack ([x ["y" ] for x in outputs ]).mean ()
236+ self .log ('test_loss' , test_loss )
237+
238+ model = ModelTrainValTest ()
203239
204240 # logger file to get meta
205241 logger = tutils .get_default_logger (tmpdir )
@@ -230,7 +266,20 @@ def test_running_test_after_fitting(tmpdir):
230266
231267def test_running_test_no_val (tmpdir ):
232268 """Verify `test()` works on a model with no `val_loader`."""
233- model = BoringModel ()
269+ class ModelTrainTest (BoringModel ):
270+ def __init__ (self , * args , ** kwargs ):
271+ super ().__init__ (* args , ** kwargs )
272+
273+ def test_step (self , batch , batch_idx ):
274+ output = self .layer (batch )
275+ loss = self .loss (batch , output )
276+ return {"y" : loss }
277+
278+ def test_epoch_end (self , outputs ) -> None :
279+ test_loss = torch .stack ([x ["y" ] for x in outputs ]).mean ()
280+ self .log ('test_loss' , test_loss )
281+
282+ model = ModelTrainTest ()
234283
235284 # logger file to get meta
236285 logger = tutils .get_default_logger (tmpdir )
0 commit comments