
Commit f7a189f

awaelchli and carmocca committed

Reset all results on epoch end (#14061)

Co-authored-by: Carlos Mocholí <[email protected]>

1 parent 695a01e · commit f7a189f

3 files changed: +29 −5 lines


src/pytorch_lightning/CHANGELOG.md (1 addition, 1 deletion)

@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
 
 
 ## [1.7.1] - 2022-08-09

src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py (1 addition, 2 deletions)

@@ -163,8 +163,7 @@ def update_train_epoch_metrics(self) -> None:
         self.log_metrics(self.metrics["log"])
 
         # reset result collection for next epoch
-        assert self.trainer._results is not None
-        self.trainer._results.reset(metrics=True)
+        self.reset_results()
 
     """
     Utilities and properties
tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py (27 additions, 2 deletions)

@@ -569,11 +569,12 @@ def on_train_epoch_end(self, trainer, pl_module):
     "accelerator",
     [
         pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
+        "cpu",
     ],
 )
 def test_metric_are_properly_reduced(tmpdir, accelerator):
     class TestingModel(BoringModel):
-        def __init__(self, *args, **kwargs) -> None:
+        def __init__(self) -> None:
             super().__init__()
             self.val_acc = Accuracy()

@@ -592,7 +593,6 @@ def validation_step(self, batch, batch_idx):
             return super().validation_step(batch, batch_idx)
 
     early_stop = EarlyStopping(monitor="val_acc", mode="max")
-
     checkpoint = ModelCheckpoint(monitor="val_acc", save_last=True, save_top_k=2, mode="max")
 
     model = TestingModel()

@@ -812,3 +812,28 @@ def training_step(self, batch, batch_idx):
             call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3),
         ]
     )
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+def test_log_on_train_start(mock_log_metrics, tmpdir):
+    """Tests that logged metrics on_train_start get reset after the first epoch."""
+
+    class MyModel(BoringModel):
+        def on_train_start(self):
+            self.log("foo", 123)
+
+    model = MyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
+    assert trainer.max_epochs > 1
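For illustration, the user-facing scenario the new test locks in (a metric logged in on_train_start should appear once and not be re-flushed at later epoch ends) can be reproduced with a plain script. Everything below, including the model class, layer size, and trainer arguments, is a self-contained example written for this note, not code taken from the repository:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class DemoModel(LightningModule):
    """Minimal stand-in module that logs a metric once in on_train_start."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def on_train_start(self):
        # Before the fix, this value could be flushed again at later epoch
        # ends because the epoch-end result collection was never reset.
        self.log("foo", 123)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)


trainer = Trainer(max_epochs=2, limit_train_batches=1, limit_val_batches=0, log_every_n_steps=1)
trainer.fit(DemoModel())
# With the fix, "foo" shows up only in the metrics emitted at step 0.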
