1313# limitations under the License.
1414import torch
1515
16- from pytorch_lightning import Trainer
16+ from pytorch_lightning import LightningDataModule , LightningModule , Trainer
17+ from pytorch_lightning .metrics .functional import accuracy
1718from pytorch_lightning .trainer .states import TrainerState
1819from pytorch_lightning .utilities import DistributedType
20+ from tests .base import EvalModelTemplate
1921from tests .helpers import BoringModel
2022from tests .helpers .utils import get_default_logger , load_model_from_checkpoint , reset_seed
2123
2224
23- def run_model_test_without_loggers (trainer_options , model , min_acc : float = 0.50 ):
25+ def run_model_test_without_loggers (
26+ trainer_options : dict , model : LightningModule , data : LightningDataModule = None , min_acc : float = 0.50
27+ ):
2428 reset_seed ()
2529
2630 # fit model
2731 trainer = Trainer (** trainer_options )
28- trainer .fit (model )
32+ trainer .fit (model , datamodule = data )
2933
3034 # correct result and ok accuracy
3135 assert trainer .state == TrainerState .FINISHED , f"Training failed with { trainer .state } "
@@ -35,12 +39,13 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50
3539 )
3640
3741 # test new model accuracy
38- test_loaders = model .test_dataloader ()
42+ test_loaders = model .test_dataloader () if not data else data . test_dataloader ()
3943 if not isinstance (test_loaders , list ):
4044 test_loaders = [test_loaders ]
4145
42- for dataloader in test_loaders :
43- run_prediction (pretrained_model , dataloader , min_acc = min_acc )
46+ if not isinstance (model , BoringModel ):
47+ for dataloader in test_loaders :
48+ run_prediction_eval_model_template (model , dataloader , min_acc = min_acc )
4449
4550 if trainer ._distrib_type in (DistributedType .DDP , DistributedType .DDP_SPAWN ):
4651 # on hpc this would work fine... but need to hack it for the purpose of the test
@@ -77,8 +82,9 @@ def run_model_test(
7782 if not isinstance (test_loaders , list ):
7883 test_loaders = [test_loaders ]
7984
80- for dataloader in test_loaders :
81- run_prediction (pretrained_model , dataloader , min_acc = min_acc )
85+ if not isinstance (model , BoringModel ):
86+ for dataloader in test_loaders :
87+ run_prediction_eval_model_template (model , dataloader , min_acc = min_acc )
8288
8389 if with_hpc :
8490 if trainer ._distrib_type in (DistributedType .DDP , DistributedType .DDP_SPAWN , DistributedType .DDP2 ):
@@ -95,14 +101,7 @@ def run_model_test(
95101 trainer .checkpoint_connector .hpc_load (checkpoint_path , on_gpu = on_gpu )
96102
97103
98- def run_prediction (trained_model , dataloader , dp = False , min_acc = 0.25 ):
99- if isinstance (trained_model , BoringModel ):
100- return _boring_model_run_prediction (trained_model , dataloader , dp , min_acc )
101- else :
102- return _eval_model_template_run_prediction (trained_model , dataloader , dp , min_acc )
103-
104-
105- def _eval_model_template_run_prediction (trained_model , dataloader , dp = False , min_acc = 0.50 ):
104+ def run_prediction_eval_model_template (trained_model , dataloader , dp = False , min_acc = 0.50 ):
106105 # run prediction on 1 batch
107106 batch = next (iter (dataloader ))
108107 x , y = batch
@@ -117,24 +116,6 @@ def _eval_model_template_run_prediction(trained_model, dataloader, dp=False, min
117116 else :
118117 with torch .no_grad ():
119118 y_hat = trained_model (x )
120- y_hat = y_hat .cpu ()
121-
122- # acc
123- labels_hat = torch .argmax (y_hat , dim = 1 )
124-
125- y = y .cpu ()
126- acc = torch .sum (y == labels_hat ).item () / (len (y ) * 1.0 )
127- acc = torch .tensor (acc )
128- acc = acc .item ()
119+ acc = accuracy (y_hat .cpu (), y .cpu ()).item ()
129120
130121 assert acc >= min_acc , f"This model is expected to get > { min_acc } in test set (it got { acc } )"
131-
132-
133- def _boring_model_run_prediction (trained_model , dataloader , dp = False , min_acc = 0.25 ):
134- # run prediction on 1 batch
135- batch = next (iter (dataloader ))
136- with torch .no_grad ():
137- output = trained_model (batch )
138- acc = trained_model .loss (batch , output )
139-
140- assert acc >= min_acc , f"This model is expected to get, { min_acc } in test set but got { acc } "
0 commit comments