🐛 Bug
I think I found a bug where errors caused by user code are misreported as a checkpointing MisconfigurationException, even though checkpointing is configured correctly.
This happens when an error (such as a RuntimeError or a CUDA OOM error) is raised during training and bubbles up to the try/except block in the train(self) function, which can be found here: https://github.com/PyTorchLightning/pytorch-lightning/blob/65247ec47cbe8f0254f444f409025585806f9113/pytorch_lightning/trainer/trainer.py#L550
Since these errors are not caught there, execution continues with the finally block. This calls self.train_loop.on_train_end(), which proceeds to save a checkpoint. If one monitors a validation metric such as val/accuracy, no value has been logged yet because the error occurred during training. As a result, the checkpointing code raises a MisconfigurationException stating that the monitored metric is not found in the returned metrics, and that exception is what the user ends up seeing instead of the original error.
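For context, this is ordinary Python behaviour: an exception raised while a finally block is running replaces the exception that was already propagating. A minimal, Lightning-independent sketch (all names and messages below are illustrative):

# Minimal sketch (plain Python, not Lightning code) of how an exception raised
# in a `finally` block replaces the one already propagating from `try`.
def train():
    try:
        raise RuntimeError("CUDA out of memory")  # the user's real error
    finally:
        # stands in for on_train_end() -> checkpoint saving, which fails
        # because the monitored validation metric was never logged
        raise Exception("val/accuracy not found in the returned metrics")

train()
# The caller only gets the second exception; the RuntimeError merely survives
# as __context__ in the traceback, not as the reported error.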
To Reproduce
- configure checkpointing to monitor a validation metric, which (by definition) is not logged during training.
- raise any error during training, such as a RuntimeError, but not a KeyboardInterrupt (that is the only one which is caught); see the reproduction sketch below this list.
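A minimal reproduction sketch, assuming pytorch-lightning 1.1.4 (the model, the data, and the val/accuracy metric name are illustrative):

# Minimal reproduction sketch: checkpoint on a validation metric, then fail in training.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class FailingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # simulate a user error (e.g. a shape bug or CUDA OOM)
        raise RuntimeError("simulated user error during training")

    def validation_step(self, batch, batch_idx):
        # never reached, so 'val/accuracy' is never logged
        self.log("val/accuracy", 0.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)


if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=1,
        num_sanity_val_steps=0,  # ensure the val metric is never logged
        callbacks=[ModelCheckpoint(monitor="val/accuracy")],
    )
    # Expected: the RuntimeError reaches the user.
    # Actual (1.1.4): a MisconfigurationException about the missing monitor key.
    trainer.fit(FailingModel())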
Expected behavior
The error should not be swallowed by the finally block; it should propagate all the way to the top so the user can see it and fix the underlying bug.
Environment
- OS: Linux
- env-type: conda
- pytorch-lightning 1.1.4
- pytorch 1.7.0
- python 3.8.6
- CUDA/cuDNN version: 11.0
- GPU models and configuration: one V100 (32 GB VRAM)
Lazy patch:
A very hacky way of at least letting the user know about their error is to modify train so that, after the except KeyboardInterrupt clause, all other errors are caught with except Exception as e and immediately printed with print(e). Unfortunately, re-raising the error with raise e does not work, because the finally block is executed first and raises its own MisconfigurationException.
The full code of train would then look like this (the changes are below the except KeyboardInterrupt clause):
def train(self):
    self.run_sanity_check(self.get_model())

    # set stage for logging
    self.logger_connector.set_stage("train")

    self.checkpoint_connector.has_trained = False

    # enable train mode
    model = self.get_model()
    model.train()
    torch.set_grad_enabled(True)

    # reload data when needed
    self.train_loop.reset_train_val_dataloaders(model)

    # hook
    self.train_loop.on_train_start()

    try:
        if self.train_loop.should_skip_training():
            return
        # run all epochs
        for epoch in range(self.current_epoch, self.max_epochs):

            # hook
            self.train_loop.on_train_epoch_start(epoch)

            with self.profiler.profile("run_training_epoch"):
                # run train epoch
                self.train_loop.run_training_epoch()

            if self.max_steps and self.max_steps <= self.global_step:
                return

            # update LR schedulers
            self.optimizer_connector.update_learning_rates(interval='epoch')

            # early stopping
            met_min_epochs = epoch >= self.min_epochs - 1
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            if self.should_stop:
                if met_min_epochs and met_min_steps:
                    return
                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...'
                )

    except KeyboardInterrupt:
        rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
        # user could press ctrl+c many times... only shutdown once
        if not self.interrupted:
            self.interrupted = True
            self._state = TrainerState.INTERRUPTED
            self.on_keyboard_interrupt()
    except Exception as e:
        print(e)  # better would be raising it, but that would be executed after the finally
    finally:
        # hook
        self.train_loop.on_train_end()

Test
A test to catch this problem in the future would be to checkpoint on a validation metric, raise an arbitrary error during training (anything but KeyboardInterrupt, which is intentionally caught), and assert that it bubbles up all the way. A sketch of such a test follows.
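A sketch of such a regression test, reusing the FailingModel from the reproduction sketch above (pytest is assumed; the test name and match string are illustrative):

# Sketch of a regression test; reuses FailingModel from the reproduction sketch above.
import pytest
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


def test_user_error_is_not_masked_by_checkpointing(tmpdir):
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        num_sanity_val_steps=0,
        callbacks=[ModelCheckpoint(monitor="val/accuracy")],
    )
    # The user's RuntimeError should reach the caller, not a
    # MisconfigurationException from the checkpoint callback.
    with pytest.raises(RuntimeError, match="simulated user error"):
        trainer.fit(FailingModel())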