Description
🐛 Bug
In TrainLoop.on_train_end, a call to check_checkpoint_callback is made, and that method in turn calls on_validation_end. As per the docs (and the fact that ModelCheckpoint fires on on_validation_end), the expectation is that it monitors validation metrics. However, if we set num_sanity_val_steps to 0 in the Trainer, validation metrics have not been logged yet when this check runs, which results in a MisconfigurationException in _validate_monitor_key.
Note that this is only an issue on the first epoch; after that, the val keys appear in the callback metrics and the issue is moot.
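A minimal, framework-free sketch of the failure mode, assuming the monitor check amounts to "the monitored key must already be present in the callback metrics". validate_monitor_key and MisconfigurationError are hypothetical stand-ins, not Lightning's actual _validate_monitor_key or MisconfigurationException, and the metric names and values are illustrative.

```python
from typing import Dict


class MisconfigurationError(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


def validate_monitor_key(monitor: str, callback_metrics: Dict[str, float]) -> None:
    # Hypothetical simplification: the monitored key must already be logged.
    if monitor not in callback_metrics:
        raise MisconfigurationError(
            f"monitor={monitor!r} not found in metrics: {sorted(callback_metrics)}"
        )


# Before any validation loop has run (first epoch, num_sanity_val_steps=0),
# only training metrics exist in the callback metrics.
first_epoch_metrics = {"loss": 0.52}
# Once validation has run, the val keys are present as well.
later_metrics = {"loss": 0.31, "val_loss": 0.40}

for metrics in (first_epoch_metrics, later_metrics):
    try:
        validate_monitor_key("val_loss", metrics)
        print("ok")
    except MisconfigurationError as err:
        print("error:", err)
```

Run as-is, the first iteration reports the missing val_loss key and the second prints ok, which matches the observation that the problem disappears after the first epoch.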
Please reproduce using the BoringModel
To Reproduce
Use the following BoringModel and post here
I cannot reproduce this with the BoringModel because it uses the deprecated x_step style (e.g. validation_step returns the loss rather than logging it). It should be updated to 1.2.6 in a separate issue.
Expected behavior
If ModelCheckpoint only implements on_validation_end, then it should only fire on that hook, not silently in on_train_end as well. If it is meant to fire in on_train_end, it should either have a second monitor specific to the callback_metrics logged during training, or its logic should be moved out of on_validation_end into a more general (less misleading) hook.
Note that callbacks have access to Trainer.state, so it is possible to move the ModelCheckpoint.on_validation_end logic into a higher-level hook and use this state information. An elegant (imo) attribute to add to ModelCheckpoint could be monitor_state, so that, for instance, a user can say "monitor metric 'loss' but only while the trainer is in state 'train'".
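The monitor_state idea can be modeled without Lightning at all. This is a sketch under stated assumptions: a single general hook (hypothetically named on_state_end) receives the current trainer state and the callback metrics. StateGatedCheckpoint, on_state_end, the state strings 'train' and 'validate', and the "lower is better" comparison are illustrative assumptions, not Lightning API.

```python
from typing import Dict, FrozenSet, List, Optional, Union


def normalize_states(monitor_state: Optional[Union[str, List[str]]]) -> Optional[FrozenSet[str]]:
    # None means "no restriction": the monitor is checked in every trainer state.
    if monitor_state is None:
        return None
    if isinstance(monitor_state, str):
        return frozenset([monitor_state])
    return frozenset(monitor_state)


class StateGatedCheckpoint:
    """Hypothetical, framework-free model of ModelCheckpoint(monitor_state=...)."""

    def __init__(self, monitor: str, monitor_state: Optional[Union[str, List[str]]] = None):
        self.monitor = monitor
        self.monitor_states = normalize_states(monitor_state)
        self.best: Optional[float] = None

    def should_monitor(self, trainer_state: str) -> bool:
        return self.monitor_states is None or trainer_state in self.monitor_states

    def on_state_end(self, trainer_state: str, callback_metrics: Dict[str, float]) -> bool:
        """Hypothetical general hook; returns True if a new best checkpoint would be saved."""
        if not self.should_monitor(trainer_state):
            return False  # Not a monitored state: skip instead of raising on a missing key.
        if self.monitor not in callback_metrics:
            # A missing key in a state the user asked for is still a real error.
            raise KeyError(f"{self.monitor!r} not logged while in state {trainer_state!r}")
        current = callback_metrics[self.monitor]
        if self.best is None or current < self.best:  # assumes lower is better
            self.best = current
            return True
        return False


ckpt = StateGatedCheckpoint(monitor="val_loss", monitor_state="validate")
print(ckpt.on_state_end("train", {"loss": 0.52}))         # False: skipped, no error
print(ckpt.on_state_end("validate", {"val_loss": 0.40}))  # True: first best value
```

The design choice here is that an unlisted state is skipped rather than treated as a misconfiguration, which is exactly the case that currently raises on the first epoch, while a missing key in a listed state still fails loudly.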
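```python
from typing import List, Optional, Union

from pytorch_lightning.callbacks import Callback


class ModelCheckpoint(Callback):
    def __init__(
        self,
        # ... existing arguments ...
        monitor: Optional[str] = None,
        # proposed: must be a subset of fit/validate/test/predict/etc.
        monitor_state: Optional[Union[str, List[str]]] = None,
        # ... existing arguments ...
    ):
        ...
```

Environment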
On PL master (1.2.6)
- PyTorch Version (e.g., 1.0): 1.7.1
- OS (e.g., Linux): linux
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- Python version: 3.7
- CUDA/cuDNN version: 10.2