@@ -14,8 +14,8 @@
 import torch
 
 from pytorch_lightning import Trainer
-from tests.base.develop_utils import load_model_from_checkpoint, get_default_logger, \
-    reset_seed
+from tests.base import BoringModel
+from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed
 
 
 def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
@@ -31,6 +31,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
     pretrained_model = load_model_from_checkpoint(
         trainer.logger,
         trainer.checkpoint_callback.best_model_path,
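+        # pass the model class so the correct LightningModule subtype is restored from the checkpoint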
+        type(model)
     )
 
     # test new model accuracy
@@ -39,15 +40,16 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
         test_loaders = [test_loaders]
 
     for dataloader in test_loaders:
-        run_prediction(dataloader, pretrained_model, min_acc=min_acc)
+        run_prediction(pretrained_model, dataloader, min_acc=min_acc)
 
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
         trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
 
-def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
+def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
+                   with_hpc: bool = True, min_acc: float = 0.25):
 
     reset_seed()
     save_dir = trainer_options['default_root_dir']
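+# min_acc: minimum score the reloaded model must reach on the test dataloaders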
@@ -56,35 +58,34 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
     logger = get_default_logger(save_dir, version=version)
     trainer_options.update(logger=logger)
 
-    if 'checkpoint_callback' not in trainer_options:
-        trainer_options.update(checkpoint_callback=True)
-
     trainer = Trainer(**trainer_options)
     initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
     result = trainer.fit(model)
     post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
 
     assert result == 1, 'trainer failed'
     # Check that the model is actually changed post-training
-    assert torch.norm(initial_values - post_train_values) > 0.1
+    change_ratio = torch.norm(initial_values - post_train_values)
+    assert change_ratio > 0.1, f"the model changed by only {change_ratio} after training"
 
     # test model loading
-    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path)
+    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
 
     # test new model accuracy
     test_loaders = model.test_dataloader()
     if not isinstance(test_loaders, list):
         test_loaders = [test_loaders]
 
     for dataloader in test_loaders:
-        run_prediction(dataloader, pretrained_model)
+        run_prediction(pretrained_model, dataloader, min_acc=min_acc)
 
     if with_hpc:
         if trainer.use_ddp or trainer.use_ddp2:
             # on hpc this would work fine... but need to hack it for the purpose of the test
             trainer.model = pretrained_model
-            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
-                trainer.init_optimizers(pretrained_model)
+            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
+                pretrained_model
+            )
 
     # test HPC saving
     trainer.checkpoint_connector.hpc_save(save_dir, logger)
@@ -93,7 +94,14 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
     trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
 
 
-def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
+def run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
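+    # BoringModel has no accuracy metric, so route it to a simpler loss-based check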
+    if isinstance(trained_model, BoringModel):
+        return _boring_model_run_prediction(trained_model, dataloader, dp, min_acc)
+    else:
+        return _eval_model_template_run_prediction(trained_model, dataloader, dp, min_acc)
+
+
+def _eval_model_template_run_prediction(trained_model, dataloader, dp=False, min_acc=0.50):
     # run prediction on 1 batch
     batch = next(iter(dataloader))
     x, y = batch
@@ -102,7 +110,7 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
     if dp:
         with torch.no_grad():
             output = trained_model(batch, 0)
-        acc = output['val_acc']
+            acc = output['val_acc']
         acc = torch.mean(acc).item()
 
     else:
@@ -119,3 +127,13 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
         acc = acc.item()
 
     assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})"
+
+
+def _boring_model_run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
+    # run prediction on 1 batch
+    batch = next(iter(dataloader))
+    with torch.no_grad():
+        output = trained_model(batch)
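+        # use the model's loss on this batch as a stand-in for an accuracy score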
+        acc = trained_model.loss(batch, output)
+
+    assert acc >= min_acc, f"This model is expected to get >= {min_acc} in the test set but got {acc}"
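
A minimal sketch of how these helpers might be exercised after this change, assuming the file is `tests/base/develop_pipelines.py` and that `BoringModel`'s default dataloaders apply; the trainer options below are illustrative placeholders, not taken from this commit:

```python
import tempfile

from tests.base import BoringModel
from tests.base.develop_pipelines import run_model_test

# hypothetical smoke-test invocation
model = BoringModel()
trainer_options = dict(
    default_root_dir=tempfile.mkdtemp(),  # run_model_test reads 'default_root_dir'
    max_epochs=1,
    limit_train_batches=4,
)
# with a BoringModel, run_prediction dispatches to the loss-based check
run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.25)
```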