
Validation metrics assumed to be logged within the first training epoch #6791

@tmcclintock

Description

🐛 Bug

In TrainLoop.on_train_end, a call to check_checkpoint_callback is made. Within that method, on_validation_end is invoked on each checkpoint callback. As per the docs (and the fact that ModelCheckpoint fires on on_validation_end), the expectation is that it monitors validation metrics. However, if we set num_sanity_val_steps to 0 in the Trainer, then no validation metrics have been logged at that point, and _validate_monitor_key raises a MisconfigurationException.
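
For reference, this is roughly the path in question (paraphrased from memory of the 1.2.x training loop source; exact signatures and names may differ slightly):

# Simplified paraphrase of pytorch_lightning/trainer/training_loop.py (1.2.x)
def on_train_end(self):
    ...
    # Forces a final checkpoint pass, regardless of whether validation ran:
    self.check_checkpoint_callback(should_update=True, is_last=True)
    ...

def check_checkpoint_callback(self, should_update, is_last=False):
    if should_update and self.trainer.checkpoint_connector.has_trained:
        model = self.trainer.lightning_module
        for cb in self.trainer.checkpoint_callbacks:
            # Fires the validation hook even though no validation metrics
            # may exist yet (e.g. num_sanity_val_steps=0 on the first epoch):
            cb.on_validation_end(self.trainer, model)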

Note that this is only an issue on the first epoch -- after that, the validation keys appear in the callback metrics and the issue is moot.

Please reproduce using the BoringModel

To Reproduce

Use the following BoringModel and post here

I cannot reproduce this with the BoringModel as written, since it uses deprecated *_step methods (e.g. validation_step returns the loss rather than logging it). The BoringModel should be updated for 1.2.6 in a separate issue.
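
For what it's worth, here is a minimal sketch of what an updated repro might look like, assuming the 1.2-style self.log API. The module and dataset mirror the BoringModel's shapes; using max_steps=1 to end training before any validation pass runs is my assumption about how to exercise the path described above, not a confirmed trigger:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class UpdatedBoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def validation_step(self, batch, batch_idx):
        # Log the metric instead of returning it (the non-deprecated API).
        self.log("val_loss", self(batch).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=1,
        max_steps=1,  # assumption: stop before the first validation pass
        num_sanity_val_steps=0,  # no sanity check, so "val_loss" is never logged
        callbacks=[ModelCheckpoint(monitor="val_loss")],
    )
    trainer.fit(
        UpdatedBoringModel(),
        DataLoader(RandomDataset(32, 64), batch_size=2),
        DataLoader(RandomDataset(32, 64), batch_size=2),
    )

If the analysis above is right, on_train_end should then call check_checkpoint_callback and fail in _validate_monitor_key, because "val_loss" never made it into callback_metrics.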

Expected behavior

If the model checkpoint only implements on_validation_end, then it should only fire on that hook, not secretly in on_train_end. If it must also fire in on_train_end, it should either take a second monitor specific to the callback_metrics logged during training, or its logic should be moved out from under on_validation_end into a more general (less misleading) hook.

Note that the callbacks have access to the Trainer.state, so it is possible to move the ModelCheckpoint.on_validation_end logic into a higher-level hook and leverage this state info. An elegant (imo) attribute to add to ModelCheckpoint would be monitor_state, so that for instance a user can say "monitor metric 'loss' but only while the trainer is in state 'train'", as sketched below.

from typing import List, Optional, Union

from pytorch_lightning.callbacks import Callback


class ModelCheckpoint(Callback):
    def __init__(
        self,
        ...
        monitor: Optional[str] = None,
        # Must be a subset of fit/validate/test/predict/etc.
        monitor_state: Optional[Union[str, List[str]]] = None,
        ...
    ):
        ...
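
Continuing the sketch, a hypothetical guard inside the same class (my assumption, not existing PL code; trainer.state is a TrainerState enum in 1.2, so the string comparison below would need adapting):

    def on_validation_end(self, trainer, pl_module):
        # Hypothetical guard: do nothing unless the trainer is in one of the
        # states the user asked to monitor (e.g. monitor_state="train").
        if self.monitor_state is not None:
            allowed = (
                [self.monitor_state]
                if isinstance(self.monitor_state, str)
                else self.monitor_state
            )
            if trainer.state not in allowed:
                return
        ...  # existing monitor validation / checkpointing logic

A user could then write ModelCheckpoint(monitor="loss", monitor_state="train") to monitor the training loss without ever tripping _validate_monitor_key on validation-only keys.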

Environment

On PL master (1.2.6)

  • PyTorch Version (e.g., 1.0): 1.7.1
  • OS (e.g., Linux): linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.7
  • CUDA/cuDNN version: 10.2


Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task), won't fix (This will not be worked on)
