Description
We are auditing the Lightning components and APIs to assess opportunities for improvement:
- #7740
- https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#
Lightning has had some recent issues filed around these hooks:
- `training_epoch_end`
- `validation_epoch_end`
- `test_epoch_end`
- `predict_epoch_end`
Examples:
- outputs in training_epoch_end contain only outputs from last batch repeated #8603
- Add predict_epoch_end and predict_step_end methods to LightningModule #8657
- Ordering of hooks #8670
- Inconsistent API for on_predict_epoch_end #8479
- Cuda out of memory #11582
- RFC: Change input format of training_epoch_end hook in case of multiple optimizers or TBPTT #9737
These hooks exist to accumulate the step-level outputs during the epoch for post-processing at the end of the epoch. However, they do not need to be part of the core LightningModule interface: users can easily track outputs directly inside their own modules.
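For reference, this is the pattern these hooks currently support: the trainer collects every step output over the epoch and hands the full list to the hook. A minimal sketch (the class name and loss computation are placeholders):

```python
import torch
from pytorch_lightning import LightningModule


class CurrentModel(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        # `outputs` is the list of every `training_step` return value
        # from this epoch, accumulated by the trainer on our behalf.
        epoch_loss = torch.stack([o["loss"] for o in outputs]).mean()
        self.log("train_loss_epoch", epoch_loss)
```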
Asking users to do this tracking offers major benefits:
- We avoid API confusion: for instance, when should users implement something in `training_epoch_end` vs `on_train_epoch_end`? This can improve the onboarding experience (one less class of hooks to learn about, and only one way to do things).
- This can also improve performance: if users implement `training_epoch_end` and don't use `outputs`, the trainer needlessly accumulates results, which wastes memory and risks OOMing (see the sketch after this list). This slowdown is not clearly visible to the user either, unless training fails completely, at which point it's a bad user experience.
- Reduced API surface area for the trainer reduces the risk of bugs like the ones listed above. These bugs disproportionately hurt user trust because the control flow isn't visible to end users; conversely, removing this class of bugs disproportionately benefits user trust.
- The current contract makes the trainer responsible for stewardship of data it doesn't directly use. Removing this support clarifies responsibilities and simplifies the loop internals.
- There's less "framework magic" at play, which means more readable user code because the tracking is explicit.
- Because the tracking is explicit, the responsibility for testing it also falls to users. In general, we should encourage users to test their own code, while the framework itself remains easily testable.
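To make the performance point concrete, here is a hypothetical module (names are illustrative) that overrides the hook without ever reading `outputs`; the trainer still accumulates every step's return value for the entire epoch:

```python
from pytorch_lightning import LightningModule


class WastefulModel(LightningModule):
    def training_step(self, batch, batch_idx):
        loss, logits = ...  # `logits` may be large (e.g. per-token predictions)
        # Because `training_epoch_end` is overridden below, the trainer keeps
        # this dict alive for every batch until the epoch ends.
        return {"loss": loss, "logits": logits}

    def training_epoch_end(self, outputs):
        # `outputs` is never read, yet one entry per batch was accumulated;
        # with large `logits` this silently wastes memory and risks OOM.
        print("epoch finished")
```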
Cons:
- (Marginally) more boilerplate code in LightningModules. For instance, users need to remember to reset the accumulated outputs (unless they explicitly want to accumulate results across epochs).
Proposal
- Deprecate `training_epoch_end`, `validation_epoch_end`, and `test_epoch_end` in v1.5
- Remove these hooks entirely, along with their corresponding calls in the loops, in v1.7
This is how easily users can implement the same tracking in their LightningModule with the existing hooks:
```python
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self._train_outputs = []  # <----- New

    def training_step(self, *args, **kwargs):
        ...
        output = ...
        self._train_outputs.append(output)  # <----- New
        return output

    def on_train_epoch_end(self) -> None:
        # process self._train_outputs
        self._train_outputs = []  # <----- New
```
So we're talking about 3 lines of code here per train/val/test/predict. I'd argue this is minimal compared to the amount of logic that usually goes into post-processing the outputs anyway.
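For example, the same pattern applied to validation (a sketch, assuming `validation_step` produces a per-batch loss tensor; the class name is illustrative):

```python
import torch
from pytorch_lightning import LightningModule


class MyValModel(LightningModule):
    def __init__(self):
        super().__init__()
        self._val_outputs = []  # <----- New

    def validation_step(self, batch, batch_idx):
        loss = ...  # compute the validation loss for this batch
        self._val_outputs.append(loss)  # <----- New
        return loss

    def on_validation_epoch_end(self) -> None:
        # post-process the accumulated outputs, e.g. an epoch-level mean
        epoch_loss = torch.stack(self._val_outputs).mean()
        self.log("val_loss_epoch", epoch_loss)
        self._val_outputs = []  # <----- New
```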
@PyTorchLightning/core-contributors
Originally posted by @ananthsub in #8690