[RFC] Simpler mechanism to publish data from LightningModule and Callback to Trainer #11715

@ananthsub

🚀 Feature

This RFC expresses a desire for a simpler mechanism to send metric data from the LightningModule and Callback to the Trainer.

Motivation

The current approach with LightningModule.log has many downsides:

  1. Poor debugging experience: we started seeing this sort of failure ([RFC] Depreceate the move_metrics_to_cpu Trainer argument. #10595) after Save the loop progress state by default #10784 . This is an example stack trace: https://gist.github.com/ananthsub/45c154145d0f852503c6a547f59e91f0 . It is very hard to tell where in the logging code I went wrong. We see this error even after updating our torchmetrics dependency.

  2. LightningModule.log conflates too many things:

  • Handles synchronization of tensors, allowing user-provided syncing functions
  • Handles aggregation of tensors across steps and for the end of the epoch
  • Selectively samples what data to log based on log_every_n_steps and the global step
  • Makes the trainer aware of which metrics to reset at the end of the epoch
  • Also handles partitioning keys based on the dataloader idx??
  • Supports passing the batch_size so that weighted averages can take it into account. If the batch_size isn't passed, Lightning silently makes its best attempt to figure it out (not foolproof!)
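To make the batch-size point concrete, here is a minimal arithmetic sketch of how an epoch-level average changes when batch sizes are ignored. The helper names and the numbers are made up for illustration and are not part of Lightning or torchmetrics.

```python
def unweighted_mean(batch_losses):
    """Average per-batch mean losses, ignoring how many samples each batch held."""
    return sum(loss for loss, _ in batch_losses) / len(batch_losses)


def weighted_mean(batch_losses):
    """Average per-batch mean losses, weighting each one by its batch size."""
    total_samples = sum(size for _, size in batch_losses)
    return sum(loss * size for loss, size in batch_losses) / total_samples


# (mean loss of the batch, number of samples in the batch); the last batch is short.
batch_losses = [(1.0, 32), (1.0, 32), (4.0, 4)]

epoch_unweighted = unweighted_mean(batch_losses)  # (1 + 1 + 4) / 3
epoch_weighted = weighted_mean(batch_losses)      # (32 + 32 + 16) / 68
```

The short final batch pulls the unweighted figure well above the per-sample average, which is the kind of silent error a wrong batch-size guess can introduce.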

Many of these assumptions came from the original Result object. This was a class that preceded the whole torchmetrics project.

The log API is not straightforward to use, given the large number of options available and the implementation differences between logging floats/tensors and logging torchmetrics Metric objects.

Pitch

Provide a new API like this:

import time
from dataclasses import dataclass
from typing import List, Mapping, Optional, Union

import torch

@dataclass
class PutData:
    name: str
    val: Union[float, torch.Tensor]
    destinations: List[str]
    timestamp: float
    step: Optional[int] = None  # left unset by the user

class PutMixin:
    def __init__(self):
        self.pl_put_records: List[PutData] = []

    def put(self, name: str, val: Union[float, torch.Tensor], destinations: Optional[List[str]] = None) -> None:
        # default: publish to every destination
        destinations = destinations or ["progress_bar", "callbacks", "loggers"]
        self.pl_put_records.append(PutData(name=name, val=val, destinations=destinations, timestamp=time.monotonic()))

    def put_dict(self, dictionary: Mapping[str, Union[float, torch.Tensor]], destinations: Optional[List[str]] = None) -> None:
        for k, v in dictionary.items():
            self.put(k, v, destinations)

Example calling code (compute_loss and batch_size are user-defined helpers):

from torchmetrics import MeanMetric

def __init__(self, ...):
    self.loss_avg = MeanMetric()
    self.metric = ...
    self.my_fancy_metric = ...

def training_step(self, batch, batch_idx):
    loss = compute_loss(batch)
    self.loss_avg.update(loss, batch_size(batch))  # no more guesswork from the trainer
    self.metric.update(batch, loss)
    self.put("loss", self.loss_avg.compute())  # the user always passes tensors instead of a mix of tensors and Metric instances
    self.put("metric", self.metric.compute(), destinations=["callbacks"])  # only send this info to callbacks, not loggers
    with self.my_fancy_metric.sync_context():
        self.put("fancy_metric", self.my_fancy_metric.compute())
    return loss

def on_train_epoch_end(self):
    self.loss_avg.reset()
    self.metric.reset()
    self.my_fancy_metric.reset()

The trainer already calls all of the hooks offered by the LightningModule & Callback APIs. The trainer already has logic that can inspect the data and reset it after every hook is called (https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/trainer.py#L1522-L1587). This means the data can be taken from here, tagged with the global_step or other metadata the trainer is aware of, and routed to the relevant destinations (callbacks/loggers/metrics).
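As a rough illustration of that trainer-side step, the sketch below drains the records after a hook, stamps them with the global step, and groups them by destination. `drain_and_route` is a hypothetical helper name, not an existing Trainer method, and plain floats stand in for tensors so the example has no dependencies.

```python
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class PutData:
    name: str
    val: float  # the RFC also allows torch.Tensor; floats keep this sketch dependency-free
    destinations: List[str]
    timestamp: float
    step: Optional[int] = None


class PutMixin:
    def __init__(self) -> None:
        self.pl_put_records: List[PutData] = []

    def put(self, name: str, val: float, destinations: Optional[List[str]] = None) -> None:
        destinations = destinations or ["progress_bar", "callbacks", "loggers"]
        self.pl_put_records.append(
            PutData(name=name, val=val, destinations=destinations, timestamp=time.monotonic())
        )


def drain_and_route(source: PutMixin, global_step: int) -> Dict[str, List[PutData]]:
    """Hypothetical trainer-side helper: stamp, group by destination, then clear."""
    routed: Dict[str, List[PutData]] = defaultdict(list)
    for record in source.pl_put_records:
        if record.step is None:
            record.step = global_step  # metadata only the trainer knows
        for destination in record.destinations:
            routed[destination].append(record)
    source.pl_put_records.clear()  # reset so the next hook starts from an empty list
    return dict(routed)


module = PutMixin()
module.put("loss", 0.25)
module.put("metric", 0.9, destinations=["callbacks"])
routed = drain_and_route(module, global_step=7)
```

Here `routed["callbacks"]` holds both records while `routed["loggers"]` holds only the loss, and the module's record list is empty again for the next hook.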

Pros:

  • Simpler API for users with fewer side-effects to consider, especially around checkpointing & syncing states
  • A simpler API means it's easier for new users to onboard. There is no need to get familiar with the different arguments for sync_dist, on_step/on_epoch, rank_zero_only, or metric_attribute, amongst others
  • Simpler implementation that backs this. This is critical for users to debug failures. A secondary benefit is it's easier for developers to maintain the framework over time.
  • No need to duplicate logic between torchmetrics and Lightning. All of the metric syncing logic is delegated to torchmetrics in user land, which solves this more elegantly
  • Users should be able to publish these metrics at any time. There should be no restrictions around when data is stored, unlike log today.
  • We don't need to store the current_fx name on the module before calling each LightningModule hook just so that we can set the defaults for on_step and on_epoch inside of log
  • The user-facing API can be wrapped in a small mixin to share between the LightningModule & Callback. Then the Trainer doesn't need to dynamically patch the LightningModule's log onto the Callback anymore: https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/connectors/callback_connector.py#L258-L262
  • With loss explicitly tracked as a metric from the user side, the user can precisely specify the batch size to compute the correct weighted average. There is no need for Lightning to try to guess the batch size from the black-box batch object, which could silently fail.
  • A new name like put makes clear that it's separate from logging like Python logging and Lightning's own Loggers. This is generally a means through which the user passes data to the Trainer for usage in other places like callbacks, progress bar, or loggers.
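The mixin-sharing point can be sketched as follows. `ToyModule` and `ToyCallback` are hypothetical stand-ins for a LightningModule and a Callback, and the mixin is trimmed down (no timestamps or steps) to keep the example short; this is an illustration of the idea, not how Lightning is implemented.

```python
from typing import List, Optional, Tuple


class PutMixin:
    """Trimmed-down stand-in for the mixin in the pitch (no timestamps or steps)."""

    def __init__(self) -> None:
        self.pl_put_records: List[Tuple[str, float, List[str]]] = []

    def put(self, name: str, val: float, destinations: Optional[List[str]] = None) -> None:
        destinations = destinations or ["progress_bar", "callbacks", "loggers"]
        self.pl_put_records.append((name, val, destinations))


class ToyModule(PutMixin):  # hypothetical stand-in for a LightningModule
    def training_step(self, batch: List[float]) -> float:
        loss = sum(batch) / len(batch)
        self.put("loss", loss)
        return loss


class ToyCallback(PutMixin):  # hypothetical stand-in for a Callback
    def on_train_batch_end(self, loss: float) -> None:
        self.put("loss_doubled", 2 * loss, destinations=["loggers"])


module, callback = ToyModule(), ToyCallback()
loss = module.training_step([1.0, 3.0])
callback.on_train_batch_end(loss)

# Both objects expose the same attribute, so the trainer can read either one
# without patching a method from the module onto the callback.
collected = module.pl_put_records + callback.pl_put_records
```

Because both classes inherit the same `put`, the trainer only needs to know about `pl_put_records`, regardless of which object produced the data.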

Alternatives

Additional context


cc @Borda @tchaton @justusschock @awaelchli @carmocca @edward-io @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy

Labels: design (includes a design discussion), feature (is an improvement or enhancement), logging (related to the `LoggerConnector` and `log()`)