Skip to content

Commit 0cd7796

Browse files
committed
Issues with checkpointing
1 parent: a5c7cd5 · commit: 0cd7796

File tree

4 files changed: 60 additions (+60) and 34 deletions (-34)

4 files changed: 60 additions (+60) and 34 deletions (-34)

tests/base/boring_model.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -106,31 +106,6 @@ def training_step(self, batch, batch_idx):
106106
loss = self.loss(batch, output)
107107
return {"loss": loss}
108108

109-
def training_step_end(self, training_step_outputs):
110-
return training_step_outputs
111-
112-
def training_epoch_end(self, outputs) -> None:
113-
train_loss = torch.stack([x["loss"] for x in outputs]).mean()
114-
self.log('train_loss', train_loss)
115-
116-
def validation_step(self, batch, batch_idx):
117-
output = self.layer(batch)
118-
loss = self.loss(batch, output)
119-
return {"x": loss}
120-
121-
def validation_epoch_end(self, outputs) -> None:
122-
val_loss = torch.stack([x["x"] for x in outputs]).mean()
123-
self.log('val_loss', val_loss)
124-
125-
def test_step(self, batch, batch_idx):
126-
output = self.layer(batch)
127-
loss = self.loss(batch, output)
128-
return {"y": loss}
129-
130-
def test_epoch_end(self, outputs) -> None:
131-
test_loss = torch.stack([x["y"] for x in outputs]).mean()
132-
self.log('test_loss', test_loss)
133-
134109
def configure_optimizers(self):
135110
optimizer = getattr(torch.optim, self.optimizer_name)(self.layer.parameters(), lr=self.learning_rate)
136111
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

tests/base/develop_pipelines.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
6161
logger = get_default_logger(save_dir, version=version)
6262
trainer_options.update(logger=logger)
6363

64+
# TODO: DEPRECATED option
6465
if "checkpoint_callback" not in trainer_options:
6566
trainer_options.update(checkpoint_callback=True)
6667

@@ -71,7 +72,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
7172

7273
assert result == 1, "trainer failed"
7374
# Check that the model is actually changed post-training
74-
assert torch.norm(initial_values - post_train_values) > 0.1
75+
change_ratio = torch.norm(initial_values - post_train_values)
76+
assert change_ratio > 0.1, f"the model is changed of {change_ratio}"
7577

7678
# test model loading
7779
pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

tests/models/test_cpu.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,34 @@ def test_all_features_cpu_model(tmpdir):
115115

116116

117117
def test_early_stopping_cpu_model(tmpdir):
118-
"""Test each of the trainer options."""
118+
"""Test each of the trainer options. Simply test the combo trainer and
119+
model; callbacks functionality tests are in /tests/callbacks"""
120+
class ModelTrainVal(BoringModel):
121+
def __init__(self, *args, **kwargs):
122+
super().__init__(*args, **kwargs)
123+
124+
def validation_step(self, batch, batch_idx):
125+
output = self.layer(batch)
126+
loss = self.loss(batch, output)
127+
return {"x": loss}
128+
129+
def validation_epoch_end(self, outputs) -> None:
130+
val_loss = torch.stack([x["x"] for x in outputs]).mean()
131+
self.log('val_loss', val_loss)
132+
119133
stopping = EarlyStopping(monitor="val_loss", min_delta=0.1)
120134
trainer_options = dict(
121135
default_root_dir=tmpdir,
122136
callbacks=[stopping],
123137
max_epochs=2,
124-
gradient_clip_val=1.0,
125-
overfit_batches=0.20,
138+
gradient_clip_val=1,
126139
track_grad_norm=2,
127-
limit_train_batches=0.1,
140+
limit_train_batches=0.2,
128141
limit_val_batches=0.1,
129142
)
130143

131-
model = BoringModel()
144+
model = ModelTrainVal()
145+
132146
tpipes.run_model_test(trainer_options, model, on_gpu=False)
133147

134148
# test freeze on cpu
@@ -199,7 +213,29 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
199213

200214
def test_running_test_after_fitting(tmpdir):
201215
"""Verify test() on fitted model."""
202-
model = BoringModel()
216+
class ModelTrainValTest(BoringModel):
217+
def __init__(self, *args, **kwargs):
218+
super().__init__(*args, **kwargs)
219+
220+
def validation_step(self, batch, batch_idx):
221+
output = self.layer(batch)
222+
loss = self.loss(batch, output)
223+
return {"x": loss}
224+
225+
def validation_epoch_end(self, outputs) -> None:
226+
val_loss = torch.stack([x["x"] for x in outputs]).mean()
227+
self.log('val_loss', val_loss)
228+
229+
def test_step(self, batch, batch_idx):
230+
output = self.layer(batch)
231+
loss = self.loss(batch, output)
232+
return {"y": loss}
233+
234+
def test_epoch_end(self, outputs) -> None:
235+
test_loss = torch.stack([x["y"] for x in outputs]).mean()
236+
self.log('test_loss', test_loss)
237+
238+
model = ModelTrainValTest()
203239

204240
# logger file to get meta
205241
logger = tutils.get_default_logger(tmpdir)
@@ -230,7 +266,20 @@ def test_running_test_after_fitting(tmpdir):
230266

231267
def test_running_test_no_val(tmpdir):
232268
"""Verify `test()` works on a model with no `val_loader`."""
233-
model = BoringModel()
269+
class ModelTrainTest(BoringModel):
270+
def __init__(self, *args, **kwargs):
271+
super().__init__(*args, **kwargs)
272+
273+
def test_step(self, batch, batch_idx):
274+
output = self.layer(batch)
275+
loss = self.loss(batch, output)
276+
return {"y": loss}
277+
278+
def test_epoch_end(self, outputs) -> None:
279+
test_loss = torch.stack([x["y"] for x in outputs]).mean()
280+
self.log('test_loss', test_loss)
281+
282+
model = ModelTrainTest()
234283

235284
# logger file to get meta
236285
logger = tutils.get_default_logger(tmpdir)

tests/models/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_multi_gpu_none_backend(tmpdir):
4646
)
4747

4848
model = BoringModel()
49-
tpipes.run_model_test(trainer_options, model)
49+
tpipes.run_model_test(trainer_options, model, min_acc=0.20)
5050

5151

5252
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")

0 commit comments

Comments
 (0)