
[RFC] Deprecate the _epoch_end hooks #8731

@ananthsub


We are auditing the Lightning components and APIs to assess opportunities for improvements:

#7740
https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#

Lightning has had some recent issues filed around these hooks:

  • training_epoch_end
  • validation_epoch_end
  • test_epoch_end
  • predict_epoch_end

Examples:

These hooks exist so that step-level outputs can be accumulated during the epoch and post-processed at the end of the epoch. However, they do not need to be part of the core LightningModule interface. Users can easily track outputs directly inside their own modules.

Asking users to do this tracking offers major benefits:

  1. We avoid API confusion: for instance, when should users implement something in training_epoch_end vs. on_train_epoch_end? Removing that choice improves the onboarding experience (one fewer class of hooks to learn about, and only one way to do things).
  2. This can also improve performance: if users implement training_epoch_end but don't use outputs, the trainer still accumulates every step's results, which wastes memory and risks running out of memory (OOM). This overhead is not clearly visible to the user unless training fails outright, and by then it has already become a bad user experience.
  3. A reduced API surface area for the trainer lowers the risk of bugs like this. These bugs disproportionately hurt user trust because the control flow isn't visible to end users. Conversely, removing this class of bugs disproportionately increases user trust.
  4. The current contract makes the trainer responsible for the stewardship of data it doesn't directly use. Removing this support clarifies responsibilities and simplifies the loop internals.
  5. There's less "framework magic" at play, which makes user code more readable because the tracking is explicit.
  6. Because the tracking is explicit, users become responsible for testing it. In general, we should encourage users to test their own code, and we should keep the framework itself easy to test.
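To make point 2 concrete, here is a heavily simplified, dependency-free sketch of the current contract. It is not Lightning's loop code. It assumes, as point 2 describes, that accumulation is switched on simply by overriding training_epoch_end. `BaseModule`, `is_overridden`, `run_epoch` and the two example classes are hypothetical stand-ins.

```python
from typing import Any, List


class BaseModule:
    """Hypothetical stand-in for the base class that defines the default hooks."""

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError

    def training_epoch_end(self, outputs: List[Any]) -> None:
        """Default hook: does nothing."""


def is_overridden(name: str, instance: object, parent: type) -> bool:
    # A subclass overrides the hook if its class attribute is a different function object.
    return getattr(type(instance), name) is not getattr(parent, name)


def run_epoch(module: BaseModule, batches: List[Any]) -> List[Any]:
    """Simplified epoch loop (not Lightning's code); returns the outputs it retained."""
    keep_outputs = is_overridden("training_epoch_end", module, BaseModule)
    outputs: List[Any] = []
    for batch_idx, batch in enumerate(batches):
        out = module.training_step(batch, batch_idx)
        if keep_outputs:
            outputs.append(out)  # every step output stays referenced until epoch end
    if keep_outputs:
        module.training_epoch_end(outputs)
    return outputs


class NeverUsesOutputs(BaseModule):
    def training_step(self, batch: Any, batch_idx: int) -> Any:
        return [float(batch)] * 1_000  # stand-in for a large tensor

    def training_epoch_end(self, outputs: List[Any]) -> None:
        # Overrides the hook but ignores `outputs`; they were retained anyway.
        self.epochs_done = getattr(self, "epochs_done", 0) + 1


class NoEpochEndHook(BaseModule):
    def training_step(self, batch: Any, batch_idx: int) -> Any:
        return [float(batch)] * 1_000
```

With this model, `run_epoch(NeverUsesOutputs(), list(range(10)))` keeps all 10 large outputs alive until the end of the epoch. `NoEpochEndHook` keeps none.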

Cons:

  1. (marginally) more boilerplate code in LightningModules. For instance, users would need to pay attention to resetting the accumulated outputs (unless they explicitly want to accumulate results across epochs).
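The reset pitfall can be illustrated with plain Python classes that stand in for a LightningModule. The hook names mirror Lightning's, but `ForgetsToReset`, `Resets` and `drive` are hypothetical. Because the hooks are ordinary methods, they can be driven directly without a Trainer, which also illustrates the testability argument in point 6.

```python
from typing import Any, List


class ForgetsToReset:
    """Plain-Python stand-in for a LightningModule; hook names mirror Lightning's."""

    def __init__(self) -> None:
        self._train_outputs: List[Any] = []
        self.epoch_sizes: List[int] = []

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        self._train_outputs.append(batch)
        return batch

    def on_train_epoch_end(self) -> None:
        # Record how many outputs are buffered at the end of this epoch.
        self.epoch_sizes.append(len(self._train_outputs))
        # Bug: the buffer is never reset, so it keeps growing across epochs.


class Resets(ForgetsToReset):
    def on_train_epoch_end(self) -> None:
        super().on_train_epoch_end()
        self._train_outputs.clear()  # the extra line of boilerplate this con refers to


def drive(module: ForgetsToReset, epochs: int, steps_per_epoch: int) -> List[int]:
    """Call the hooks in loop order, roughly as a trainer would."""
    for _ in range(epochs):
        for batch_idx in range(steps_per_epoch):
            module.training_step(batch_idx, batch_idx)
        module.on_train_epoch_end()
    return module.epoch_sizes
```

Here `drive(ForgetsToReset(), 3, 4)` returns `[4, 8, 12]`, while `drive(Resets(), 3, 4)` returns `[4, 4, 4]`. Omitting the reset is only correct when cross-epoch accumulation is actually intended.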

Proposal

  • Deprecate training_epoch_end, validation_epoch_end, and test_epoch_end in v1.5
  • Remove these hooks entirely, along with their corresponding calls in the loops, in v1.7
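During the deprecation window, the trainer could warn users who still override these hooks. The following is a hypothetical sketch that uses only the standard `warnings` module. Lightning has its own deprecation utilities, and this sketch does not claim to reproduce them. `REPLACEMENTS`, `warn_on_deprecated_hooks`, `Base` and `OldStyle` are illustrative names.

```python
import warnings
from typing import Dict, List

# Deprecated hook -> hook the logic should move into.
REPLACEMENTS: Dict[str, str] = {
    "training_epoch_end": "on_train_epoch_end",
    "validation_epoch_end": "on_validation_epoch_end",
    "test_epoch_end": "on_test_epoch_end",
}


def warn_on_deprecated_hooks(model: object, base: type) -> List[str]:
    """Hypothetical check: emit a DeprecationWarning for each deprecated hook `model` overrides."""
    overridden: List[str] = []
    for old, new in REPLACEMENTS.items():
        impl = getattr(type(model), old, None)
        if impl is not None and impl is not getattr(base, old, None):
            overridden.append(old)
            warnings.warn(
                f"`{old}` is deprecated in v1.5 and will be removed in v1.7. "
                f"Collect step outputs on the module and process them in `{new}` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
    return overridden


class Base:
    """Stand-in for the base class that provides default (no-op) hooks."""

    def training_epoch_end(self, outputs): ...
    def validation_epoch_end(self, outputs): ...
    def test_epoch_end(self, outputs): ...


class OldStyle(Base):
    def validation_epoch_end(self, outputs):
        return None
```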

Here is how users can implement this in their LightningModule with the existing hooks:

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()  # required so the underlying nn.Module is initialized
        self._train_outputs = []  # <----- New

    def training_step(self, *args, **kwargs):
        ...
        output = ...
        self._train_outputs.append(output)  # <----- New
        return output

    def on_train_epoch_end(self) -> None:
        # process self._train_outputs
        self._train_outputs = []  # <----- New

So we're talking about 3 lines of code per train/val/test/predict stage. I'd argue this is minimal compared to the amount of logic that usually goes into post-processing the outputs anyway.
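If even those 3 lines feel repetitive across stages, users could factor them into a small helper. The mixin below is a hypothetical sketch and not part of Lightning. `pop_outputs` returns the buffer and resets it in a single call, so the reset from the Cons section can't be forgotten. To keep the sketch dependency-free, `Demo` is a plain class standing in for a LightningModule.

```python
from typing import Any, Dict, List


class EpochOutputsMixin:
    """Hypothetical helper (not part of Lightning): one output buffer per stage."""

    def _buffers(self) -> Dict[str, List[Any]]:
        # Store the dict directly in __dict__ so the mixin needs no __init__ and
        # does not go through torch.nn.Module.__setattr__ when mixed into a module.
        return self.__dict__.setdefault("_epoch_outputs", {})

    def track(self, stage: str, output: Any) -> Any:
        """Append a step output to the buffer for `stage` and pass it through."""
        self._buffers().setdefault(stage, []).append(output)
        return output

    def pop_outputs(self, stage: str) -> List[Any]:
        """Return the outputs collected for `stage` and reset its buffer."""
        return self._buffers().pop(stage, [])


class Demo(EpochOutputsMixin):
    """Plain class standing in for a LightningModule."""

    def validation_step(self, batch: List[int], batch_idx: int) -> int:
        return self.track("val", sum(batch))

    def on_validation_epoch_end(self) -> None:
        self.last_val_total = sum(self.pop_outputs("val"))
```

In a real module, the mixin would presumably be listed first, as in `class MyModel(EpochOutputsMixin, LightningModule)`. This sketch has not been tested against Lightning itself.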

@PyTorchLightning/core-contributors

Originally posted by @ananthsub in #8690

Labels: deprecation (includes a deprecation), design (includes a design discussion), discussion (in a discussion stage)