1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import os
15+ from unittest import mock
16+
1417import pytest
1518import torch
19+ import torch .nn .functional as F
1620
1721import pytorch_lightning as pl
1822import tests .helpers .pipelines as tpipes
1923import tests .helpers .utils as tutils
2024from pytorch_lightning .callbacks import EarlyStopping
2125from pytorch_lightning .core import memory
22- from tests .base import EvalModelTemplate
26+ from tests .helpers import BoringModel
27+ from tests .helpers .datamodules import ClassifDataModule
28+ from tests .helpers .simple_models import ClassificationModel
2329
2430PRETEND_N_OF_GPUS = 16
2531
2632
33+ class CustomClassificationModelDP (ClassificationModel ):
34+
35+ def _step (self , batch , batch_idx ):
36+ x , y = batch
37+ logits = self (x )
38+ return {'logits' : logits , 'y' : y }
39+
40+ def training_step (self , batch , batch_idx ):
41+ out = self ._step (batch , batch_idx )
42+ loss = F .cross_entropy (out ['logits' ], out ['y' ])
43+ return loss
44+
45+ def validation_step (self , batch , batch_idx ):
46+ return self ._step (batch , batch_idx )
47+
48+ def test_step (self , batch , batch_idx ):
49+ return self ._step (batch , batch_idx )
50+
51+ def validation_step_end (self , outputs ):
52+ self .log ('val_acc' , self .valid_acc (outputs ['logits' ], outputs ['y' ]))
53+
54+ def test_step_end (self , outputs ):
55+ self .log ('test_acc' , self .test_acc (outputs ['logits' ], outputs ['y' ]))
56+
57+
2758@pytest .mark .skipif (torch .cuda .device_count () < 2 , reason = "test requires multi-GPU machine" )
2859def test_multi_gpu_early_stop_dp (tmpdir ):
2960 """Make sure DDP works. with early stopping"""
3061 tutils .set_random_master_port ()
3162
63+ dm = ClassifDataModule ()
64+ model = CustomClassificationModelDP ()
65+
3266 trainer_options = dict (
3367 default_root_dir = tmpdir ,
34- callbacks = [EarlyStopping ()],
68+ callbacks = [EarlyStopping (monitor = 'val_acc' )],
3569 max_epochs = 50 ,
3670 limit_train_batches = 10 ,
3771 limit_val_batches = 10 ,
3872 gpus = [0 , 1 ],
3973 accelerator = 'dp' ,
4074 )
4175
42- model = EvalModelTemplate ()
43- tpipes .run_model_test (trainer_options , model )
76+ tpipes .run_model_test (trainer_options , model , dm )
4477
4578
4679@pytest .mark .skipif (torch .cuda .device_count () < 2 , reason = "test requires multi-GPU machine" )
@@ -57,22 +90,21 @@ def test_multi_gpu_model_dp(tmpdir):
5790 progress_bar_refresh_rate = 0 ,
5891 )
5992
60- model = EvalModelTemplate ()
93+ model = BoringModel ()
6194
6295 tpipes .run_model_test (trainer_options , model )
6396
6497 # test memory helper functions
6598 memory .get_memory_profile ('min_max' )
6699
67100
101+ @mock .patch .dict (os .environ , {"CUDA_VISIBLE_DEVICES" : "0,1" })
68102@pytest .mark .skipif (torch .cuda .device_count () < 2 , reason = "test requires multi-GPU machine" )
69103def test_dp_test (tmpdir ):
70104 tutils .set_random_master_port ()
71105
72- import os
73- os .environ ['CUDA_VISIBLE_DEVICES' ] = '0,1'
74-
75- model = EvalModelTemplate ()
106+ dm = ClassifDataModule ()
107+ model = CustomClassificationModelDP ()
76108 trainer = pl .Trainer (
77109 default_root_dir = tmpdir ,
78110 max_epochs = 2 ,
@@ -81,17 +113,17 @@ def test_dp_test(tmpdir):
81113 gpus = [0 , 1 ],
82114 accelerator = 'dp' ,
83115 )
84- trainer .fit (model )
116+ trainer .fit (model , datamodule = dm )
85117 assert 'ckpt' in trainer .checkpoint_callback .best_model_path
86- results = trainer .test ()
118+ results = trainer .test (datamodule = dm )
87119 assert 'test_acc' in results [0 ]
88120
89- old_weights = model .c_d1 .weight .clone ().detach ().cpu ()
121+ old_weights = model .layer_0 .weight .clone ().detach ().cpu ()
90122
91- results = trainer .test (model )
123+ results = trainer .test (model , datamodule = dm )
92124 assert 'test_acc' in results [0 ]
93125
94126 # make sure weights didn't change
95- new_weights = model .c_d1 .weight .clone ().detach ().cpu ()
127+ new_weights = model .layer_0 .weight .clone ().detach ().cpu ()
96128
97129 assert torch .all (torch .eq (old_weights , new_weights ))
0 commit comments