Deprecate summarize() off LightningModule #8478

@ananthsub

🚀 Feature

Background
We are auditing the Lightning components and APIs to assess opportunities for improvement.

One item that came up was summarize(), which is defined on the LightningModule. It does not need to be part of the core API.

Objective
Instead, we'd like to abstract this out into a separate utility function that accepts a LightningModule and prints out the model summary.
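As a rough illustration of the objective, here is a minimal sketch of what a standalone summary utility could look like. The name `summarize_module` and the table layout are assumptions made for this example, not an existing Lightning API. It is written against plain `nn.Module`, so it would also accept a LightningModule, which subclasses `nn.Module`.

```python
from torch import nn


def summarize_module(module: nn.Module) -> str:
    """Return a plain-text table of the module's direct children and parameter counts.

    Hypothetical helper for illustration only; not part of Lightning.
    """
    rows = []
    for name, child in module.named_children():
        # Count every parameter owned by this child, including nested ones.
        n_params = sum(p.numel() for p in child.parameters())
        rows.append((name, type(child).__name__, n_params))

    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)

    lines = [f"{'Name':<12}{'Type':<16}{'Params':>10}"]
    lines += [f"{name:<12}{kind:<16}{count:>10}" for name, kind, count in rows]
    lines.append(f"Trainable params: {trainable}")
    lines.append(f"Total params: {total}")
    return "\n".join(lines)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
summary = summarize_module(model)
print(summary)
```

Returning a string instead of printing directly is a design choice in this sketch: the caller decides whether to print it, write it to a file, or send it elsewhere.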

Motivation

  • Simplify the LightningModule
  • Simplify the core Trainer logic
  • Offer summarization as an opt-in utility
  • Demonstrate how projects can integrate with other libraries for summarization extensions (e.g. torchinfo, fvcore's flop_count)

Pitch

Extensions

Define a ModelSummary Callback which calls this utility function. A callback in Lightning naturally fits this extension purpose. It generalizes well across lightning modules, has great flexibility for when it can be called, and allows users to customize the summarization logic (e.g. integrate other libraries like torchsummary more easily).
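A minimal sketch of such a callback might look like the following. The class name `SummaryCallback` and its `summarize_fn` and `sink` arguments are hypothetical and chosen for this example. The only Lightning pieces relied on are the `Callback` base class and its `on_fit_start(trainer, pl_module)` hook. The import fallback is there only so the sketch can be read and run without Lightning installed.

```python
from typing import Any, Callable

try:  # Lightning 2.x package layout
    from lightning.pytorch.callbacks import Callback
except ImportError:
    try:  # older standalone package
        from pytorch_lightning.callbacks import Callback
    except ImportError:
        Callback = object  # fallback so the sketch still runs without Lightning


class SummaryCallback(Callback):
    """Hypothetical callback: summarize the module when fitting starts."""

    def __init__(
        self,
        summarize_fn: Callable[[Any], str],
        sink: Callable[[str], None] = print,
    ) -> None:
        super().__init__()
        self.summarize_fn = summarize_fn  # any function mapping a module to text
        self.sink = sink  # where the text goes: stdout, a file writer, a logger

    def on_fit_start(self, trainer: Any, pl_module: Any) -> None:
        # Called by the Trainer on every rank at the start of fit().
        self.sink(self.summarize_fn(pl_module))
```

Under these assumptions, the same body could be placed in `on_validation_start`, `on_test_start`, or `on_predict_start` to cover the other entry points. Because the hook runs on every rank, a real implementation would probably restrict only the output step to rank 0, not the summary computation.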

Why do we want to remove this from the core trainer logic?

  • The current implementation runs on global rank 0 only, in order to avoid printing out multiple summary tables. However, running this on rank 0 alone will break model-parallel use cases that require communication across ranks. This can lead to subtle failures if example_input_array is set as a property on the LightningModule. For instance, a model wrapped with FSDP will break because parameters need to be all-gathered across layers across ranks.
  • Users may want to configure this summarization for different points of execution. For instance, they may want to call it at the start of trainer.fit(), trainer.validate(), trainer.test(), or trainer.predict(). Right now, it is hardcoded to run only during fit().
  • In case the LightningModule leverages PyTorch LazyModules, users may want to generate this summary only after the first batch is processed in order to get accurate parameter estimations.
  • Users may want to customize where they save the summary. Right now, it's printed to stdout, but this could also be useful to save to a file or upload to another service for tracking the run.
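The LazyModules point in the list above can be illustrated with plain PyTorch. The helper `count_parameters` below is a hypothetical name used only for this sketch. It shows why a summary taken before the first batch undercounts: a lazy layer's parameters have no shape until the first forward pass.

```python
from typing import Tuple

import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter


def count_parameters(module: nn.Module) -> Tuple[int, int]:
    """Return (materialized parameter count, number of still-lazy parameter tensors)."""
    known, lazy = 0, 0
    for p in module.parameters():
        if isinstance(p, UninitializedParameter):
            lazy += 1  # shape is unknown until the first forward pass
        else:
            known += p.numel()
    return known, lazy


model = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.Linear(8, 2))
before = count_parameters(model)
model(torch.zeros(1, 4))  # first batch: the lazy layer infers in_features=4 here
after = count_parameters(model)
print(before, after)
```

In this example `before` is `(18, 2)`, since only the final `nn.Linear(8, 2)` is materialized, and `after` is `(58, 0)`. A callback that runs after the first batch would therefore report the accurate count.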

Offering this as a callback allows us to deprecate weights_summary from the Trainer constructor. This benefits the project by reducing the number of custom constructor args the Trainer accepts in favor of a generic one (the callbacks argument).

Another extension:

  • Generalize the utility function for summarization to accept any nn.Module rather than relying on a LightningModule. This can be useful to summarize arbitrary nn.Modules, such as a single submodule, instead of the entire LightningModule.

Alternatives

Keep as is

Additional context

Labels: deprecation, design, feature, help wanted
