1919import tests .helpers .utils as tutils
2020from pytorch_lightning .callbacks import EarlyStopping
2121from pytorch_lightning .core import memory
22- from tests .base import EvalModelTemplate
2322from tests .helpers import BoringModel
2423from tests .helpers .datamodules import ClassifDataModule
2524from tests .helpers .simple_models import ClassificationModel
@@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
7675 import os
7776 os .environ ['CUDA_VISIBLE_DEVICES' ] = '0,1'
7877
79- model = EvalModelTemplate ()
78+ dm = ClassifDataModule ()
79+ model = ClassificationModel ()
8080 trainer = pl .Trainer (
8181 default_root_dir = tmpdir ,
8282 max_epochs = 2 ,
@@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
8585 gpus = [0 , 1 ],
8686 accelerator = 'dp' ,
8787 )
88- trainer .fit (model )
88+ trainer .fit (model , datamodule = dm )
8989 assert 'ckpt' in trainer .checkpoint_callback .best_model_path
90- results = trainer .test ()
90+ results = trainer .test (datamodule = dm )
9191 assert 'test_acc' in results [0 ]
9292
9393 old_weights = model .c_d1 .weight .clone ().detach ().cpu ()
9494
95- results = trainer .test (model )
95+ results = trainer .test (model , datamodule = dm )
9696 assert 'test_acc' in results [0 ]
9797
9898 # make sure weights didn't change
0 commit comments