Skip to content

Commit cb12b58

Browse files
author
SeanNaren
committed
Clean up override test, fix function name
1 parent 6d24bd4 commit cb12b58

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

tests/models/test_hooks.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,9 @@ def training_epoch_end(self, outputs):
9494

9595
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
9696
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
97-
num_epochs = 1
9897

9998
class LoggingCallback(pl.Callback):
100-
101-
def on_train_epoch_end(self, trainer, pl_module):
99+
def on_train_epoch_start(self, trainer, pl_module):
102100
self.len_outputs = 0
103101

104102
def on_train_epoch_end(self, trainer, pl_module, outputs):
@@ -110,7 +108,6 @@ def on_train_epoch_start(self):
110108
self.num_train_batches = 0
111109

112110
def training_epoch_end(self, outputs): # Overridden
113-
pass
114111
return
115112

116113
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
@@ -129,19 +126,20 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
129126

130127
callback = LoggingCallback()
131128
trainer = Trainer(
132-
max_epochs=num_epochs,
129+
max_epochs=1,
133130
default_root_dir=tmpdir,
134131
overfit_batches=2,
135132
callbacks=[callback],
136133
)
137134

138-
result = trainer.fit(overridden_model)
135+
trainer.fit(overridden_model)
136+
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
137+
# if training_epoch_end is overridden
139138
assert callback.len_outputs == overridden_model.num_train_batches
140-
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
141139

142-
result = trainer.fit(not_overridden_model)
143-
assert callback.len_outputs == 0
140+
trainer.fit(not_overridden_model)
144141
# outputs from on_train_batch_end should be empty
142+
assert callback.len_outputs == 0
145143

146144

147145
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")

0 commit comments

Comments
 (0)