Skip to content

Commit 53cdd58

Browse files
committed
Fixed tests for logging and checkpointing
1 parent 0cd7796 commit 53cdd58

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

tests/base/boring_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,25 @@ def training_step(self, batch, batch_idx):
106106
loss = self.loss(batch, output)
107107
return {"loss": loss}
108108

109+
def training_epoch_end(self, outputs) -> None:
110+
torch.stack([x["loss"] for x in outputs]).mean()
111+
112+
def validation_step(self, batch, batch_idx):
113+
output = self.layer(batch)
114+
loss = self.loss(batch, output)
115+
return {"x": loss}
116+
117+
def validation_epoch_end(self, outputs) -> None:
118+
torch.stack([x['x'] for x in outputs]).mean()
119+
120+
def test_step(self, batch, batch_idx):
121+
output = self.layer(batch)
122+
loss = self.loss(batch, output)
123+
return {"y": loss}
124+
125+
def test_epoch_end(self, outputs) -> None:
126+
torch.stack([x["y"] for x in outputs]).mean()
127+
109128
def configure_optimizers(self):
110129
optimizer = getattr(torch.optim, self.optimizer_name)(self.layer.parameters(), lr=self.learning_rate)
111130
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

tests/models/test_cpu.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,6 @@ class ModelTrainVal(BoringModel):
121121
def __init__(self, *args, **kwargs):
122122
super().__init__(*args, **kwargs)
123123

124-
def validation_step(self, batch, batch_idx):
125-
output = self.layer(batch)
126-
loss = self.loss(batch, output)
127-
return {"x": loss}
128-
129124
def validation_epoch_end(self, outputs) -> None:
130125
val_loss = torch.stack([x["x"] for x in outputs]).mean()
131126
self.log('val_loss', val_loss)
@@ -270,6 +265,9 @@ class ModelTrainTest(BoringModel):
270265
def __init__(self, *args, **kwargs):
271266
super().__init__(*args, **kwargs)
272267

268+
def val_loader(self):
269+
pass
270+
273271
def test_step(self, batch, batch_idx):
274272
output = self.layer(batch)
275273
loss = self.loss(batch, output)

0 commit comments

Comments
 (0)