Metrics reduction during logging/checkpointing #5146

@justusschock

🐛 Bug

When I log a metric object with `self.log` and monitor it for checkpointing/early stopping, as shown in the code below, I get:

Traceback (most recent call last):
  File "pytorch_lightning/trainer/trainer.py", line 521, in train
    self.train_loop.run_training_epoch()
  File "pytorch_lightning/trainer/training_loop.py", line 590, in run_training_epoch
    self.trainer.run_evaluation(test_mode=False)
  File "pytorch_lightning/trainer/trainer.py", line 628, in run_evaluation
    self.evaluation_loop.on_evaluation_end()
  File "pytorch_lightning/trainer/evaluation_loop.py", line 111, in on_evaluation_end
    self.trainer.call_hook('on_validation_end', *args, **kwargs)
  File "pytorch_lightning/trainer/trainer.py", line 887, in call_hook
    trainer_hook(*args, **kwargs)
  File "pytorch_lightning/trainer/callback_hook.py", line 177, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "pytorch_lightning/callbacks/early_stopping.py", line 164, in on_validation_end
    self._run_early_stopping_check(trainer, pl_module)
  File "pytorch_lightning/callbacks/early_stopping.py", line 205, in _run_early_stopping_check
    current = torch.tensor(current, device=pl_module.device)
RuntimeError: Could not infer dtype of Accuracy
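
The last frame suggests that the callback received the `Accuracy` object itself rather than a number. Below is a minimal sketch of the kind of reduction that seems to be missing, written in plain Python so it runs without Lightning. `FakeAccuracy` and `resolve_monitor_value` are hypothetical names used for illustration and are not part of `pytorch_lightning`; the only assumption about real metrics is that they expose a `compute()` method.

```python
class FakeAccuracy:
    """Hypothetical stand-in for a metric object with update/compute semantics."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # Count how many predictions match their targets.
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Reduce the accumulated state to a single number.
        return self.correct / self.total


def resolve_monitor_value(value):
    """Return a plain float for a monitored value (hypothetical helper)."""
    if isinstance(value, type):
        # A class object (not an instance) cannot be reduced to a number.
        raise TypeError(f"cannot reduce class {value.__name__} to a scalar")
    if isinstance(value, (int, float)):
        return float(value)
    compute = getattr(value, "compute", None)
    if callable(compute):
        # Metric-like object: ask it for its reduced value.
        return float(compute())
    raise TypeError(f"cannot reduce {type(value).__name__} to a scalar")


acc = FakeAccuracy()
acc.update([1, 0, 1], [1, 1, 1])
monitored = resolve_monitor_value(acc)
```

With a step like this, the value handed to the early-stopping comparison would be a float instead of the metric object. Whether the fix belongs in the callbacks or in the logger connector is left open here.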

Please reproduce using the BoringModel and post here

from tests.base.boring_model import BoringModel, RandomDataset
import torch
import pytorch_lightning as pl


class CustomModel(BoringModel):
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
        self.log('train_acc', self.acc, on_step=True, on_epoch=True)
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
        self.log('val_acc', self.acc, on_step=True, on_epoch=True)
        return super().validation_step(batch, batch_idx)

if __name__ == '__main__':
    early_stop = pl.callbacks.EarlyStopping(monitor='val_acc', mode='max')
    checkpoint = pl.callbacks.ModelCheckpoint(
        monitor='val_acc',
        save_last=True,
        save_top_k=5,
        mode='max',
    )

    trainer = pl.Trainer(gpus=[0], max_epochs=20, callbacks=[early_stop, checkpoint])
    trainer.fit(CustomModel(), torch.utils.data.DataLoader(RandomDataset(32, 500)))

Environment

pl version: 1.1.0

Additional context

Sometimes I also get:

  File "pytorch_lightning/trainer/trainer.py", line 470, in fit
    results = self.accelerator_backend.train()
  File "pytorch_lightning/accelerators/gpu_accelerator.py", line 66, in train
    results = self.train_or_test()
  File "pytorch_lightning/accelerators/accelerator.py", line 65, in train_or_test
    results = self.trainer.train()
  File "pytorch_lightning/trainer/trainer.py", line 521, in train
    self.train_loop.run_training_epoch()
  File "pytorch_lightning/trainer/training_loop.py", line 627, in run_training_epoch
    self.run_on_epoch_end_hook(epoch_output)
  File "pytorch_lightning/trainer/training_loop.py", line 856, in run_on_epoch_end_hook
    self.trainer.logger_connector.on_train_epoch_end()
  File "pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py", line 367, in on_train_epoch_end
    self.cached_results.has_batch_loop_finished = True
  File "pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py", line 441, in has_batch_loop_finished
    self.auto_reduce_results_on_epoch_end()
  File "pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py", line 431, in auto_reduce_results_on_epoch_end
    hook_result.auto_reduce_results_on_epoch_end()
  File "pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py", line 217, in auto_reduce_results_on_epoch_end
    tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
  File "pytorch_lightning/core/step_result.py", line 566, in reduce_across_time
    value = torch.tensor(value)
RuntimeError: Could not infer dtype of ABCMeta

I have not been able to reproduce this second error with open-source code so far, but I assume it has the same cause.
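
One possible reading of the two messages, assuming `torch.tensor` reports the type name of the object it could not convert: the first error would come from a metric instance, and the second from a class object, since the type of a class derived from `abc.ABC` is `ABCMeta`. The sketch below uses hypothetical stand-in classes, not the real `pl.metrics` ones, and only shows how the two names arise.

```python
from abc import ABC


class Metric(ABC):
    """Hypothetical stand-in for an abstract metric base class."""


class Accuracy(Metric):
    """Hypothetical stand-in for a concrete metric."""


def type_name(value):
    # Name of the value's type, i.e. what an error message of the form
    # "Could not infer dtype of <name>" would plausibly contain.
    return type(value).__name__


instance_name = type_name(Accuracy())  # an instance of the metric
class_name = type_name(Accuracy)       # the class object itself
print(instance_name, class_name)
```

If this reading is right, the second traceback would mean that a metric class, rather than an instance or a tensor, ended up among the values being reduced across time.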

cc @tchaton

Labels

bug, checkpointing, help wanted, logging, priority: 0