Errors within try/except of train(self) are misrepresented as checkpointing MisconfigurationException #5766

@l-salewski

🐛 Bug

I think I found a bug where user-caused errors are misrepresented as a checkpointing MisconfigurationException even though checkpointing is configured correctly.

This happens when errors raised during training (such as RuntimeErrors or CUDA OOM errors) bubble up to the try/except block in the train(self) function, which can be found here: https://github.com/PyTorchLightning/pytorch-lightning/blob/65247ec47cbe8f0254f444f409025585806f9113/pytorch_lightning/trainer/trainer.py#L550
Since these errors are not caught there, execution proceeds to the finally block, which calls self.train_loop.on_train_end(), which in turn saves a checkpoint. If one monitors a validation metric such as val/accuracy, no value has been logged because the error occurred during training. The checkpointing code therefore raises a MisconfigurationException stating that the monitored metric is not found in the returned metrics.
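The mechanics described above can be shown with a stdlib-only sketch (all names here are hypothetical stand-ins, not Lightning's actual code): an exception raised inside a finally block replaces the in-flight training error, which survives only as `__context__`.

```python
def run_training_epoch():
    # stand-in for the user's real problem during training
    raise RuntimeError("CUDA out of memory")

def on_train_end():
    # stand-in for the checkpoint callback failing because
    # val/accuracy was never logged
    raise Exception("MisconfigurationException: metric 'val/accuracy' not found")

def train():
    try:
        run_training_epoch()
    except KeyboardInterrupt:
        pass  # the only exception train() handles
    finally:
        # raises its own exception, masking the RuntimeError the user needs to see
        on_train_end()

try:
    train()
except Exception as e:
    # the user is shown the MisconfigurationException, not the RuntimeError
    print(type(e).__name__, "->", e)
```

Note that the original RuntimeError is not lost entirely: Python chains it as `e.__context__`, but the headline error the user sees is the misleading checkpointing one.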

To Reproduce

  • Configure checkpointing to monitor a validation metric, which (by definition) is only logged during validation, not training.
  • Raise any error during training, such as a RuntimeError, but not KeyboardInterrupt (the only one that is caught).

Expected behavior

The error should not be masked by the finally block; it should propagate all the way to the top so the user can see it and fix the underlying bug.
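One way to get this expected behavior is to record the training error before cleanup and re-raise it afterwards, so a failure inside cleanup cannot mask it. The sketch below uses hypothetical names and is not Lightning's actual fix:

```python
def train_safely(run_epochs, on_train_end):
    """Run training, always attempt cleanup, but let the original error win."""
    training_error = None
    try:
        run_epochs()
    except KeyboardInterrupt:
        print("Detected KeyboardInterrupt, attempting graceful shutdown...")
    except BaseException as e:
        training_error = e  # remember the user's real error
    try:
        on_train_end()  # cleanup (e.g. checkpointing) always runs
    except Exception:
        if training_error is None:
            raise  # cleanup failed on its own; surface that
        # otherwise swallow the secondary error so the real one propagates
    if training_error is not None:
        raise training_error

# the RuntimeError reaches the caller even though cleanup also raises
def bad_epoch():
    raise RuntimeError("user bug in training_step")

def bad_cleanup():
    raise ValueError("MisconfigurationException: monitored metric not found")
```

Calling `train_safely(bad_epoch, bad_cleanup)` raises the RuntimeError, not the cleanup error.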

Environment

  • OS: Linux
  • env-type: conda
  • pytorch-lightning 1.1.4
  • pytorch 1.7.0
  • python 3.8.6
  • CUDA/cuDNN version: 11.0
  • GPU models and configuration: One v100, 32GB vRAM

Lazy patch:

A very hacky way of at least letting the user know about their error is to modify train so that, after except KeyboardInterrupt, all other errors are caught with except Exception as e and immediately printed with print(e). Unfortunately, re-raising with raise e does not work, because the finally code is executed first and raises its own MisconfigurationException.

The full code of train would look like this (changes are below the except KeyboardInterrupt clause):

    def train(self):
        self.run_sanity_check(self.get_model())

        # set stage for logging
        self.logger_connector.set_stage("train")

        self.checkpoint_connector.has_trained = False

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            if self.train_loop.should_skip_training():
                return
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                with self.profiler.profile("run_training_epoch"):
                    # run train epoch
                    self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
                    return

                # update LR schedulers
                self.optimizer_connector.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        return
                    log.info(
                        'Trainer was signaled to stop but required minimum epochs'
                        f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                        ' not been met. Training will continue...'
                    )

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()
        except Exception as e:
            print(e)  # better would be raising it, but that would be executed after the finally
        finally:
            # hook
            self.train_loop.on_train_end()

Test

A test to catch this problem in the future would be to checkpoint on a validation metric, raise any error other than KeyboardInterrupt (which is already handled) during training, and assert that the error bubbles all the way up.
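As a stdlib-only sketch of that regression test (a stand-in trainer with hypothetical names, not a real Lightning test): the fixed trainer must let the training error escape even though its cleanup step would otherwise raise.

```python
def fake_train(fixed: bool):
    """Stand-in for Trainer.train(): training fails, cleanup may also fail."""
    try:
        raise RuntimeError("user bug in training_step")
    except KeyboardInterrupt:
        pass  # the only exception handled, as in the real train()
    finally:
        if not fixed:
            # unfixed behavior: cleanup raises and masks the RuntimeError
            raise Exception("MisconfigurationException: 'val/accuracy' not found")

def test_training_error_bubbles_up():
    try:
        fake_train(fixed=True)
    except RuntimeError:
        return  # expected: the user's error reached the caller
    raise AssertionError("RuntimeError was masked by cleanup")
```

With `fixed=False` the test would fail, because the MisconfigurationException replaces the RuntimeError, which is exactly the bug reported here.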
