🚀 Feature
Background
We are auditing the Lightning components and APIs to assess opportunities for improvements:
- Review Lightning architecture & API #7740
- https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#
One item that came up was `summarize()`, defined on the LightningModule. This does not need to be part of the core API.
Objective
Instead, we'd like to extract this into a separate utility function which accepts a LightningModule and prints out the model summary.
Motivation
- Simplify the LightningModule
- Simplify the core Trainer logic
- Offer summarization as an opt-in utility
- Demonstrate how projects can integrate with other libraries for summarization extensions (e.g. torchinfo, fvcore's flop_count)
Pitch
- Create a new utility function for summarization that accepts a LightningModule: `def summarize(module: LightningModule) -> ModelSummary`. The implementation should be nearly identical to the existing method on the LightningModule.
- Add this utility under pytorch_lightning/utilities/model_summary.py: https://github.com/PyTorchLightning/pytorch-lightning/tree/master/pytorch_lightning/utilities
- Moreover, we can move the layer and model summary classes here too since they're not really part of the core Lightning API: https://github.com/PyTorchLightning/pytorch-lightning/blob/6604fc1344e1b8a459c45a5a2157aa7fc60d950d/pytorch_lightning/core/memory.py#L37-L325
- Replace the implementation of `LightningModule.summarize()` with this new utility function
- Mark `LightningModule.summarize()` as deprecated in v1.5 and slated for removal in v1.7
- Replace the Trainer's call to `LightningModule.summarize()` by directly calling the new utility function: https://github.com/PyTorchLightning/pytorch-lightning/blob/6604fc1344e1b8a459c45a5a2157aa7fc60d950d/pytorch_lightning/trainer/trainer.py#L1000-L1003
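A minimal sketch of the shape this utility could take. To keep the example dependency-free, `FakeParam`/`FakeModule` are stand-ins for torch parameters and the LightningModule, and `ModelSummary` is a simplified stand-in for the existing summary class; in the real implementation, `summarize` would accept a LightningModule:

```python
# Sketch of the proposed free function (names are assumptions based on the pitch).
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class FakeParam:
    """Stand-in for a torch parameter: only the shape matters for counting."""
    shape: Tuple[int, ...]

    def numel(self) -> int:
        n = 1
        for d in self.shape:
            n *= d
        return n


@dataclass
class FakeModule:
    """Stand-in for a LightningModule: a flat list of (name, param) pairs."""
    params: List[Tuple[str, FakeParam]] = field(default_factory=list)

    def named_parameters(self):
        return iter(self.params)


@dataclass
class ModelSummary:
    """Simplified stand-in for the summary object the utility would return."""
    layer_names: List[str]
    param_counts: List[int]

    @property
    def total_parameters(self) -> int:
        return sum(self.param_counts)

    def __str__(self) -> str:
        rows = [f"{name:<20} {count:>10,}" for name, count in zip(self.layer_names, self.param_counts)]
        rows.append(f"{'Total':<20} {self.total_parameters:>10,}")
        return "\n".join(rows)


def summarize(module: FakeModule) -> ModelSummary:
    """Proposed free function: build a summary from a module, no method on the module needed."""
    names, counts = [], []
    for name, param in module.named_parameters():
        names.append(name)
        counts.append(param.numel())
    return ModelSummary(names, counts)


model = FakeModule([("encoder.weight", FakeParam((128, 64))), ("decoder.weight", FakeParam((64, 10)))])
summary = summarize(model)
print(summary.total_parameters)  # 128*64 + 64*10 = 8832
```

The key point of the design is that the summarization logic lives entirely outside the module class, so the LightningModule only needs to expose what any nn.Module already does.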
Extensions
Define a ModelSummary Callback which calls this utility function. A callback in Lightning naturally fits this extension purpose. It generalizes well across lightning modules, has great flexibility for when it can be called, and allows users to customize the summarization logic (e.g. integrate other libraries like torchsummary more easily).
- With this callback available, this logic can be removed from the core Trainer in order to be more pluggable: https://github.com/PyTorchLightning/pytorch-lightning/blob/6604fc1344e1b8a459c45a5a2157aa7fc60d950d/pytorch_lightning/trainer/trainer.py#L1000-L1004
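A rough sketch of what such a callback could look like. `Callback`, `FakeTrainer`, and the one-line `summarize` stand-in below are hypothetical stand-ins so the example runs without pytorch_lightning installed; in the real version, the base class would be `pytorch_lightning.callbacks.Callback` and `summarize` would be the proposed utility:

```python
class Callback:
    """Stand-in for pytorch_lightning.callbacks.Callback."""
    def on_fit_start(self, trainer, pl_module):
        pass


def summarize(pl_module):
    """Stand-in for the proposed utility; returns a one-line summary string."""
    return f"model={pl_module.__class__.__name__}"


class ModelSummaryCallback(Callback):
    """Prints the model summary once, on rank 0, when fitting starts.

    Users could subclass this to summarize at other points (validate/test/predict),
    swap in another library (e.g. torchinfo), or write the summary to a file
    instead of stdout.
    """

    def __init__(self, write=print):
        self.write = write  # customizable sink: print, file handle, logger, ...

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            self.write(summarize(pl_module))


class FakeTrainer:
    """Stand-in for the Trainer: only global_rank is needed here."""
    global_rank = 0


class MyModel:
    pass


lines = []
cb = ModelSummaryCallback(write=lines.append)
cb.on_fit_start(FakeTrainer(), MyModel())
print(lines[0])  # model=MyModel
```

Because the sink and the hook are both overridable, the rank-0 check, the output destination, and the point of execution all become user decisions rather than Trainer internals.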
Why do we want to remove this from the core trainer logic?
- The current implementation runs on global rank 0 only in order to avoid printing out multiple summary tables. However, running this on rank 0 will break for model parallel use cases that require communication across ranks. This can lead to subtle failures if `example_input_array` is set as a property on the LightningModule. For instance, a model wrapped with FSDP will break because parameters need to be all-gathered across layers across ranks.
- Users may want to configure this summarization for different points of execution. For instance, calling this at the start of `trainer.fit()`, `trainer.validate()`, `trainer.test()` or `trainer.predict()`. Right now, this is hardcoded to run only during `fit()`.
- In case the LightningModule leverages PyTorch LazyModules, users may want to generate this summary only after the first batch is processed in order to get accurate parameter estimations.
- Users may want to customize where they save the summary. Right now, it's printed to stdout, but this could also be useful to save to a file or upload to another service for tracking the run.
Offering this as a callback allows us to deprecate `weights_summary` from the Trainer constructor. This benefits the project by reducing the number of custom constructor args the Trainer accepts in favor of a generic one (via the `callbacks` argument).
Another extension:
- Generalize the summarization utility function to accept any nn.Module rather than requiring a LightningModule. This can be useful to summarize arbitrary submodules instead of only the full LightningModule.
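The generalization mostly amounts to relaxing the accepted type: anything that exposes `named_parameters()` (which every nn.Module does) can be summarized. A dependency-free sketch of that idea, using a `Protocol` and stand-in classes in place of torch (names here are illustrative, not the actual API):

```python
from typing import Iterable, Protocol, Tuple


class FakeParam:
    """Stand-in for a torch parameter."""
    def __init__(self, *shape: int):
        self.shape = shape

    def numel(self) -> int:
        n = 1
        for d in self.shape:
            n *= d
        return n


class SupportsNamedParameters(Protocol):
    """Structural type satisfied by nn.Module, LightningModule, and any submodule."""
    def named_parameters(self) -> Iterable[Tuple[str, FakeParam]]: ...


class FakeSubmodule:
    """Stand-in for an arbitrary nn.Module, e.g. just the encoder of a larger model."""
    def __init__(self, params):
        self._params = params

    def named_parameters(self):
        return iter(self._params)


def summarize(module: SupportsNamedParameters) -> int:
    """Generalized utility: works on any module-like object; returns total parameter count."""
    return sum(p.numel() for _, p in module.named_parameters())


# Summarize only a submodule instead of the full model:
encoder = FakeSubmodule([("weight", FakeParam(16, 8)), ("bias", FakeParam(8))])
print(summarize(encoder))  # 16*8 + 8 = 136
```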
Alternatives
Keep as is