Skip to content

Commit 99a6161

Browse files
committed
Add a first batch of tests for Trainer.validate(…)
1 parent 860fef5 commit 99a6161

File tree

6 files changed

+240
-1
lines changed

6 files changed

+240
-1
lines changed

tests/trainer/test_config_validator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,23 @@ def test_test_loop_config(tmpdir):
9292
model = EvalModelTemplate(**hparams)
9393
model.test_step = None
9494
trainer.test(model, test_dataloaders=model.dataloader(train=False))
95+
96+
97+
def test_validation_loop_config(tmpdir):
    """
    Check that ``Trainer.validate`` emits a ``UserWarning`` when either the
    validation data or the validation loop is missing.
    """
    default_hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # validation_step is defined, but there is no val_dataloader to feed it
    with pytest.warns(UserWarning):
        model_without_data = EvalModelTemplate(**default_hparams)
        model_without_data.val_dataloader = None
        trainer.validate(model_without_data)

    # a dataloader is passed explicitly, but validation_step has been removed
    with pytest.warns(UserWarning):
        model_without_loop = EvalModelTemplate(**default_hparams)
        model_without_loop.validation_step = None
        trainer.validate(
            model_without_loop,
            val_dataloaders=model_without_loop.dataloader(train=False),
        )

tests/trainer/test_dataloaders.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,48 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
170170
trainer.test(ckpt_path=ckpt_path)
171171

172172

173+
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_validate_dataloader(tmpdir, ckpt_path):
    """Fit a model that exposes several val dataloaders, then validate it from the requested checkpoint."""

    template = EvalModelTemplate()

    class MultipleValDataloaderModel(EvalModelTemplate):
        # every hook delegates to the multi-dataloader variant defined on the template

        def val_dataloader(self):
            return template.val_dataloader__multiple()

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            return template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)

        def validation_epoch_end(self, outputs):
            return template.validation_epoch_end__multiple_dataloaders(outputs)

    model = MultipleValDataloaderModel()

    # train briefly on a fraction of the data
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )
    trainer.fit(model)

    # 'specific' stands for an explicit path: use the best checkpoint written during fit
    ckpt_to_load = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
    trainer.validate(ckpt_path=ckpt_to_load)

    # both validation loaders must have been attached to the trainer
    assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly'

    # predictions must be acceptable on every validation set
    for val_loader in trainer.val_dataloaders:
        tpipes.run_prediction(val_loader, trainer.model)

    # validating a second time from the same checkpoint must also succeed
    trainer.validate(ckpt_path=ckpt_to_load)
213+
214+
173215
def test_train_dataloader_passed_to_fit(tmpdir):
174216
"""Verify that train dataloader can be passed to fit """
175217

tests/trainer/test_optimizers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,24 @@ def test_init_optimizers_during_testing(tmpdir):
335335
assert len(trainer.optimizer_frequencies) == 0
336336

337337

338+
def test_init_optimizers_during_validation(tmpdir):
    """
    Test that the trainer holds no optimizers, LR schedulers or optimizer
    frequencies after running only ``Trainer.validate``, even though the model
    configures several schedulers.
    """
    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

    trainer = Trainer(
        default_root_dir=tmpdir,
        # only validate() is called below, so the validation batches are the ones to cap
        limit_val_batches=10,
    )
    trainer.validate(model, ckpt_path=None)

    assert len(trainer.lr_schedulers) == 0
    assert len(trainer.optimizers) == 0
    assert len(trainer.optimizer_frequencies) == 0
354+
355+
338356
def test_multiple_optimizers_callbacks(tmpdir):
339357
"""
340358
Tests that multiple optimizers can be used with callbacks

tests/trainer/test_states.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ class StateSnapshotCallback(Callback):
2323

2424
def __init__(self, snapshot_method: str):
    """
    Args:
        snapshot_method: name of the hook in which ``trainer.state`` is recorded.
    """
    super().__init__()
    # must name one of the hooks this callback implements
    assert snapshot_method in ['on_batch_start', 'on_validation_batch_start', 'on_test_batch_start']
    self.snapshot_method = snapshot_method
    # stays None until the selected hook fires
    self.trainer_state = None
2929

3030
def on_batch_start(self, trainer, pl_module):
    """Record ``trainer.state`` if this hook is the configured snapshot point."""
    if self.snapshot_method != 'on_batch_start':
        return
    self.trainer_state = trainer.state
3333

34+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
    """Record ``trainer.state`` if this hook is the configured snapshot point."""
    if self.snapshot_method != 'on_validation_batch_start':
        return
    self.trainer_state = trainer.state
37+
3438
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
3539
if self.snapshot_method == 'on_test_batch_start':
3640
self.trainer_state = trainer.state
@@ -191,6 +195,40 @@ def test_finished_state_after_test(tmpdir):
191195
assert trainer.state == TrainerState.FINISHED
192196

193197

198+
def test_running_state_during_validation(tmpdir):
    """ Tests that state is set to RUNNING while validation batches are being processed """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

    # snapshots trainer.state from inside the validation-batch hook
    state_probe = StateSnapshotCallback(snapshot_method='on_validation_batch_start')

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        callbacks=[state_probe],
    )
    trainer.validate(model)

    assert state_probe.trainer_state == TrainerState.RUNNING
215+
216+
217+
def test_finished_state_after_validation(tmpdir):
    """ Tests that state is FINISHED once ``Trainer.validate`` has returned """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    trainer.validate(model)

    assert trainer.state == TrainerState.FINISHED
230+
231+
194232
@pytest.mark.parametrize("extra_params", [
195233
pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
196234
pytest.param(dict(max_steps=1), id='Single-Step'),

tests/trainer/test_trainer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,47 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
747747
assert trainer.tested_ckpt_path == ckpt_path
748748

749749

750+
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k):
    """
    Check which checkpoint ``Trainer.validate`` reports having loaded for every
    combination of ``ckpt_path`` mode and ``save_top_k`` setting.

    NOTE(review): the assertions read ``trainer.tested_ckpt_path`` - confirm that
    ``validate`` is meant to populate the same attribute as ``test``.
    """
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(
        max_epochs=2,
        progress_bar_refresh_rate=0,
        default_root_dir=tmpdir,
        checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k),
    )
    trainer.fit(model)

    nothing_saved = save_top_k == 0

    if ckpt_path is None:
        # no checkpoint is loaded: the weights from the end of training are used
        trainer.validate(ckpt_path=ckpt_path)
        assert trainer.tested_ckpt_path is None
        return

    if ckpt_path == "best":
        # 'best' asks for the best weights, which cannot work when nothing is saved
        if nothing_saved:
            with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
                trainer.validate(ckpt_path=ckpt_path)
            return
        trainer.validate(ckpt_path=ckpt_path)
        assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
        return

    # remaining case: an explicit checkpoint file
    if nothing_saved:
        with pytest.raises(FileNotFoundError):
            trainer.validate(ckpt_path="random.ckpt")
        return

    # pick one of the checkpoints written during fit
    checkpoint_dir = Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints"
    saved_checkpoints = list(checkpoint_dir.iterdir())
    ckpt_path = str(saved_checkpoints[0].absolute())
    trainer.validate(ckpt_path=ckpt_path)
    assert trainer.tested_ckpt_path == ckpt_path
789+
790+
750791
def test_disabled_training(tmpdir):
751792
"""Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
752793

@@ -1448,6 +1489,10 @@ def setup(self, model, stage):
14481489
assert trainer.stage == "test"
14491490
assert trainer.get_model().stage == "test"
14501491

1492+
trainer.validate(ckpt_path=None)
1493+
assert trainer.stage == "validation"
1494+
assert trainer.get_model().stage == "validation"
1495+
14511496

14521497
@pytest.mark.parametrize(
14531498
"train_batches, max_steps, log_interval",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
import torch
16+
17+
import pytorch_lightning as pl
18+
import tests.base.develop_utils as tutils
19+
from tests.base import EvalModelTemplate
20+
21+
22+
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
def test_single_gpu_validate(tmpdir):
    """
    Fit on a single GPU, then check that ``Trainer.validate`` reports ``val_acc``
    both with and without an explicit model, and that validating leaves the
    model weights unchanged.
    """
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=10,
        limit_val_batches=10,
        # only device 0 is used, so one visible GPU is enough to run this test
        gpus=[0],
    )
    trainer.fit(model)
    assert 'ckpt' in trainer.checkpoint_callback.best_model_path

    # validate without passing a model
    results = trainer.validate()
    assert 'val_acc' in results[0]

    # snapshot the weights before validating the in-memory model
    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.validate(model)
    assert 'val_acc' in results[0]

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))
48+
49+
50+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_ddp_spawn_validate(tmpdir):
    """
    Fit with the ``ddp_spawn`` backend on two GPUs, then check that
    ``Trainer.validate`` reports ``val_acc`` both with and without an explicit
    model, and that validating leaves the model weights unchanged.
    """
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend='ddp_spawn',
    )
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert 'ckpt' in trainer.checkpoint_callback.best_model_path

    # validate without passing a model
    first_results = trainer.validate()
    assert 'val_acc' in first_results[0]

    weights_before = model.c_d1.weight.clone().detach().cpu()

    # validate the in-memory model explicitly
    second_results = trainer.validate(model)
    assert 'val_acc' in second_results[0]

    # validation must not have modified the layer weights
    weights_after = model.c_d1.weight.clone().detach().cpu()
    assert torch.all(torch.eq(weights_before, weights_after))

0 commit comments

Comments
 (0)