
[RFC] Deprecate weights_summary from the Trainer constructor #9043

@ananthsub

Description


Proposed refactoring or deprecation

Motivation

We are auditing the Lightning components and APIs to assess opportunities for improvement.

This is a follow-up to #8478 and #9006.

Why do we want to remove this from the core trainer logic?

  • We need a way for users to customize more of the inputs to the model summary over time without affecting the Trainer API. Today, changes to the model summary API also require changes in the core trainer (e.g. the addition of max_depth). Decoupling the two gives model summarization more room to grow without cascading changes elsewhere (see the sketch after this list).
  • Users may want to configure summarization for different points of execution. Right now it is hardcoded to run only during fit(), but users might want to trigger it multiple times, during any of trainer.fit(), trainer.validate(), trainer.test() or trainer.predict().
  • Users may want to customize where the summary goes. Right now it is printed to stdout, but it could also be useful to save it to a file or upload it to another service that tracks the run.
  • The current implementation runs on global rank 0 only, to avoid printing multiple summary tables. However, running on rank 0 alone breaks model-parallel use cases that require communication across ranks. This can lead to subtle failures when example_input_array is set on the LightningModule, because the summary runs a forward pass with it. For instance, a model wrapped with FSDP will break: its parameters need to be all-gathered across ranks, so every rank has to participate, not just rank 0.
  • If the LightningModule uses PyTorch LazyModules, users may want to generate the summary only after the first batch has been processed, since parameter-size estimates taken before lazy modules are materialized would be misleading.
  • AFAICT, this is the only piece of logic that runs between the on_pretrain_routine_start/end hooks. Would we still need those hooks if the summarization logic were removed from the trainer? Why doesn't this happen in on_train_start today? We don't have on_prevalidation_routine_start/end hooks; the necessity of these hooks for training isn't clear to me, and deprecating them as well could bring greater API clarity & simplification.
    https://github.com/PyTorchLightning/pytorch-lightning/blob/8a931732ae5135e3e55d9c7b7031d81837e5798a/pytorch_lightning/trainer/trainer.py#L1103-L1113
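
For concreteness, here is a minimal sketch of the status quo described above. The weights_summary values ("top", "full", None) and the fit-only behavior reflect the current Trainer API; the toy model and data are invented purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8
)

# Summarization is configured once, on the Trainer constructor ...
trainer = pl.Trainer(weights_summary="full", max_epochs=1)

# ... and the table is printed exactly once, on global rank 0, during fit().
# There is no switch to re-run it for validate()/test()/predict(), redirect it
# to a file, or defer it until after the first batch (e.g. for LazyModules).
trainer.fit(ToyModel(), train_loader)
```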

Pitch

A callback is a natural fit for this extension point. It generalizes across LightningModules, is flexible about when it runs, and lets users customize the summarization logic (e.g. to integrate other libraries more easily).
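
A rough, hypothetical sketch of what that could look like (not a concrete API proposal): the callback class is made up, the ModelSummary import from pytorch_lightning.core.memory and its max_depth argument are assumed from the current codebase, and the chosen hooks are just examples.

```python
import pytorch_lightning as pl
from pytorch_lightning.core.memory import ModelSummary  # assumed import path


class ModelSummaryCallback(pl.Callback):
    """Hypothetical callback that owns the summarization logic instead of the Trainer."""

    def __init__(self, max_depth=1, output_path=None):
        # New knobs (depth, output destination, ...) grow here, not on the Trainer.
        self.max_depth = max_depth
        self.output_path = output_path

    def _summarize(self, trainer, pl_module):
        # max_depth is assumed to be accepted here; the exact signature may differ.
        summary = str(ModelSummary(pl_module, max_depth=self.max_depth))
        if self.output_path is not None:
            # e.g. persist to a file or push to an experiment tracker
            with open(self.output_path, "w") as f:
                f.write(summary)
        elif trainer.is_global_zero:
            print(summary)

    # The user picks which entry points trigger a summary, not the Trainer.
    def on_fit_start(self, trainer, pl_module):
        self._summarize(trainer, pl_module)

    def on_test_start(self, trainer, pl_module):
        self._summarize(trainer, pl_module)


# Opt-in usage: no weights_summary argument on the Trainer at all.
trainer = pl.Trainer(callbacks=[ModelSummaryCallback(max_depth=2)])
```

With this shape, new options attach to the callback, the user decides when and where the summary is produced, and disabling it entirely is just a matter of not passing the callback.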

Additional context

The model summary is enabled by default right now. Whether it should remain opt-out or become opt-in is likely the core issue we have to resolve: #8478 (comment)

Seeking @edenafek's and @tchaton's input on this.


