Description
Proposed refactoring or deprecation
Change the format of the outputs that the `training_epoch_end` and `training_step_end` hooks receive. This is a breaking change and requires careful deprecation. It will affect users with either multiple optimizers or truncated backprop if they also implement the aforementioned hooks.
Motivation
When using multiple optimizers or truncated backprop (or both), the inputs passed to the `training_epoch_end` hook are the outputs from the training step arranged in a 2D or 3D nested list of lists. The shape of this multi-dimensional structure is `(num_optimizers, num_batches, num_tbptt_splits)` in the general case; when `num_optimizers=1` or truncated backprop is deactivated, the corresponding dimensions are squeezed. The problem is that the order of these dimensions does not correspond to the loop structure:
```python
for batch in dataloader:
    for split in batch:
        for opt in optimizers:
            ...
```
This means the current output format will never generalize to loop customization, since the ordering is arbitrary. The permutation of dimensions is currently hard-coded and will break for custom loops.
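For illustration, a minimal sketch of how the old format nests the step outputs, assuming two optimizers, two batches, and two TBPTT splits (the contents of each leaf entry depend on what `training_step` returns):

```python
# Old format: outputs[optimizer_idx][batch_idx][split_idx]
# (illustrative only; dimensions of size 1 are squeezed away)
outputs = [
    [  # optimizer 0
        [{"loss": ...}, {"loss": ...}],  # batch 0: one entry per TBPTT split
        [{"loss": ...}, {"loss": ...}],  # batch 1
    ],
    [  # optimizer 1
        [{"loss": ...}, {"loss": ...}],
        [{"loss": ...}, {"loss": ...}],
    ],
]
```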
Pitch
Deprecate the current format and make it consistent with the loop structure, i.e., adopt the format `(num_batches, num_tbptt_splits, num_optimizers)`. This corresponds 1:1 with the loop structure. The standardization will unblock custom loops with arbitrary nesting and make output aggregation possible with less effort.
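A corresponding sketch of the proposed format, under the same assumptions as above:

```python
# New format: outputs[batch_idx][split_idx][optimizer_idx],
# matching the nesting order of the training loop
outputs = [
    [  # batch 0
        [{"loss": ...}, {"loss": ...}],  # split 0: one entry per optimizer
        [{"loss": ...}, {"loss": ...}],  # split 1
    ],
    [  # batch 1
        [{"loss": ...}, {"loss": ...}],
        [{"loss": ...}, {"loss": ...}],
    ],
]
```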
Proposed deprecation plan:
- In 1.5, log a message that the format will change in the future (if multiple optimizers are used and the hook is overridden).
- In 1.5, users update their code following our recommendation and signal this by adding a new argument to the hook:

  ```python
  def training_step_end(self, outputs, new_format=True):
      ...

  def training_epoch_end(self, outputs, new_format=True):
      ...
  ```

  This will trigger the loop to call the hook with the new format for `outputs` instead of the old one.
- In 1.7, the new format will be used unconditionally, which is a breaking change for users who have not adapted their code by then. The `new_format=True/False` argument will become ineffective and can be removed again.
Note: the only purpose the `new_format` argument serves is to let our loop inspect the hook's signature and infer which format the user expects. The loop will not pass a value for it, so the user must give it a default (keyword) argument.
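A minimal sketch of how the loop could perform this inspection (the helper name `_uses_new_format` and the use of `inspect.signature` are assumptions for illustration, not the actual implementation):

```python
import inspect

def _uses_new_format(hook) -> bool:
    """Return True if the user opted in by adding a `new_format` parameter to the hook."""
    params = inspect.signature(hook).parameters
    return "new_format" in params and params["new_format"].default is True

# Inside the loop, something along these lines could decide which layout to build:
# outputs = new_layout if _uses_new_format(model.training_epoch_end) else old_layout
```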
Alternative deprecation plan:
Instead of a `new_format` argument in the signature, one can also add an attribute to the LightningModule:
```python
class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.v1_7_training_epoch_end_format = True
```
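Under this alternative, the loop would check the attribute instead of the hook signature; a hedged sketch (the attribute name is taken from the example above, the check itself is an assumption):

```python
# Fall back to the old format unless the module opts in via the attribute
use_new_format = getattr(model, "v1_7_training_epoch_end_format", False)
```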
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.