From c8ef4c6ef8bebe584b8d4bc3feedfa71fa61ac1c Mon Sep 17 00:00:00 2001 From: Manan Goel Date: Fri, 24 Jun 2022 00:35:52 +0000 Subject: [PATCH 1/2] Fixed the bug by changing step variable to trainer.global_step --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ff882912625d0..702949b53a2f5 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -99,7 +99,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: if step is None: # added metrics for convenience scalar_metrics.setdefault("epoch", self.trainer.current_epoch) - step = self.trainer.fit_loop.epoch_loop._batches_that_stepped + step = self.trainer.global_step # log actual metrics for logger in self.trainer.loggers: From 8478ca68a51041e5f04da9f43e64e2723bcaed66 Mon Sep 17 00:00:00 2001 From: Manan Goel Date: Fri, 24 Jun 2022 01:23:30 +0000 Subject: [PATCH 2/2] Updated tests according to new behaviour of the step argument and updated step to trainer.global_step - 1 to make it 0 based indexing --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- tests/tests_pytorch/loggers/test_all.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 702949b53a2f5..a582f0e8ab910 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -99,7 +99,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: if step is None: # added metrics for convenience scalar_metrics.setdefault("epoch", self.trainer.current_epoch) - step = self.trainer.global_step + step = self.trainer.global_step - 1 # log actual metrics for logger in self.trainer.loggers: diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 96d1016cc612b..f77d27a5ec649 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -144,14 +144,14 @@ def log_metrics(self, metrics, step): expected = [ (0, ["epoch", "train_some_val"]), (0, ["early_stop_on", "epoch", "val_loss"]), - (1, ["epoch", "test_loss"]), + (0, ["epoch", "test_loss"]), ] assert log_metric_names == expected else: expected = [ (0, ["epoch", "train_some_val"]), (0, ["early_stop_on", "epoch", "val_loss"]), - (1, ["epoch", "test_loss"]), + (0, ["epoch", "test_loss"]), ] assert log_metric_names == expected