🚀 Feature
This RFC expresses a desire for a simpler mechanism to send metric data from the LightningModule and Callback to the Trainer.
Motivation
The current approach with LightningModule.log has many downsides:
- Poor debugging experience: we started seeing this sort of failure ([RFC] Depreceate the move_metrics_to_cpu Trainer argument. #10595) after Save the loop progress state by default #10784. Example stacktrace: https://gist.github.com/ananthsub/45c154145d0f852503c6a547f59e91f0. It is very hard to tell where in the logging I went wrong, and we see this error even after updating our torchmetrics dependency.
- LightningModule.log conflates too many things:
  - Handles synchronization of tensors, allowing user-provided syncing functions
  - Handles aggregation of tensors across steps and for the end of the epoch
  - Selectively samples what data to log based on `log_every_n_steps` and the global step
  - Makes the trainer aware of which metrics to reset at the end of the epoch
  - Also handles partitioning keys based on the dataloader idx??
  - Supports passing the batch_size for weighted averages to take into account. If the batch_size isn't passed, Lightning silently makes its best attempt to figure it out (not foolproof!)
Many of these assumptions came from the original Result object. This was a class that preceded the whole torchmetrics project.
The log API is not straightforward to use given the large number of options available, and its behavior differs when logging floats/tensors vs. torchmetrics Metric objects.
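For illustration, here is roughly what a single call looks like today when several of these concerns are exercised at once (a sketch only: compute_loss and len(batch) are placeholders, and the argument names follow the current LightningModule.log signature, whose defaults vary across versions):

def training_step(self, batch, batch_idx):
    loss = compute_loss(batch)
    self.log(
        "train_loss",
        loss,
        prog_bar=True,          # routing: progress bar
        logger=True,            # routing: loggers
        on_step=True,           # aggregation: log the per-step value
        on_epoch=True,          # aggregation: also reduce across the epoch
        reduce_fx="mean",       # aggregation: how to reduce across steps
        sync_dist=True,         # synchronization across processes
        batch_size=len(batch),  # weighting for the epoch-level average
        rank_zero_only=False,   # distributed bookkeeping
    )
    return loss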
Pitch
Provide a new API like this:
import time
from dataclasses import dataclass
from typing import List, Mapping, Optional, Union

import torch


@dataclass
class PutData:
    name: str
    val: torch.Tensor
    destinations: List[str]
    timestamp: float
    step: Optional[int] = None


class PutMixin:
    def __init__(self):
        super().__init__()  # cooperate with LightningModule / Callback in the MRO
        self.pl_put_records: List[PutData] = []

    def put(self, name: str, val: Union[float, torch.Tensor], destinations: Optional[List[str]] = None) -> None:
        destinations = destinations or ["progress_bar", "callbacks", "loggers"]
        self.pl_put_records.append(PutData(name=name, val=val, destinations=destinations, timestamp=time.monotonic()))

    def put_dict(self, dictionary: Mapping[str, Union[float, torch.Tensor]], destinations: Optional[List[str]] = None) -> None:
        for k, v in dictionary.items():
            self.put(k, v, destinations)

Example calling code:
from torchmetrics import MeanMetric

def __init__(self, ...):
    self.loss_avg = MeanMetric()
    self.metric = ...
    self.my_fancy_metric = ...

def training_step(self, batch, batch_idx):
    loss = compute_loss(batch)
    self.loss_avg.update(loss, batch_size(batch))  # no more guesswork from the trainer
    self.metric.update(batch, loss)
    self.put("loss", self.loss_avg.compute())  # the user always passes tensors instead of a mix of tensors and Metric instances
    self.put("metric", self.metric.compute(), destinations=["callbacks"])  # only send this info to callbacks, not loggers
    with self.my_fancy_metric.sync_context():
        self.put("fancy_metric", self.my_fancy_metric.compute())
    return loss

def on_train_epoch_end(self):
    self.loss_avg.reset()
    self.metric.reset()
    self.my_fancy_metric.reset()

The trainer already calls all of the hooks offered by the LightningModule & Callback APIs, and its hook-dispatch logic can inspect the stored data and reset it after every hook call: https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/trainer.py#L1522-L1587. This means the data can be collected there, tagged with the global_step or other metadata the trainer is aware of, and routed to the relevant destinations (callbacks/loggers/metrics), as sketched below.
Pros:
- Simpler API for users with fewer side-effects to consider, especially around checkpointing & syncing states
- Simpler API means it's easier to onboard new users. No need to get familiar with the different arguments for `sync_dist`, `on_step`/`on_epoch`, `rank_zero_only` or `metric_attribute`, amongst others
- Simpler implementation that backs this. This is critical for users to debug failures. A secondary benefit is that it's easier for developers to maintain the framework over time.
- No need to duplicate logic between torchmetrics and Lightning. All of the metric syncing logic is delegated to torchmetrics in user land, which solves this more elegantly
- Users should be able to publish these metrics at any time. There should be no restrictions around when data is stored, unlike `log` today.
- We don't need to store the current_fx name on the module before calling each LightningModule hook just so that we can set the defaults for on_step and on_epoch inside of `log`
- The user-facing API can be wrapped in a small mixin shared between the LightningModule & Callback (see the sketch after this list). Then the Trainer doesn't need to dynamically patch the LightningModule's log onto the Callback anymore: https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/connectors/callback_connector.py#L258-L262
- With loss explicitly tracked as a metric from the user side, the user can precisely specify the batch size to compute the correct weighted average. There is no need for Lightning to try to guess the batch size from the black-box batch object, which could silently fail.
- A new name like `put` makes clear that it's separate from logging, i.e. Python logging and Lightning's own Loggers. This is generally a means through which the user passes data to the Trainer for use in other places like callbacks, the progress bar, or loggers.
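To make the mixin-sharing point concrete, a hypothetical callback could reuse the same PutMixin directly, with no patching by the Trainer (a sketch only: PutLossCallback is made up, and it assumes validation_step returns a dict with a "loss" key):

import pytorch_lightning as pl
from torchmetrics import MeanMetric

class PutLossCallback(PutMixin, pl.Callback):
    """Hypothetical callback that tracks the average validation loss and
    publishes it through the same put API the LightningModule uses."""

    def __init__(self):
        super().__init__()
        self.val_loss_avg = MeanMetric()

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # assumes the LightningModule returns {"loss": ...} from validation_step
        self.val_loss_avg.update(outputs["loss"])
        self.put("val_loss_avg", self.val_loss_avg.compute(), destinations=["progress_bar", "loggers"])

    def on_validation_epoch_end(self, trainer, pl_module):
        # put_dict is just a convenience wrapper over put
        self.put_dict({"val_loss_avg_epoch": self.val_loss_avg.compute()}, destinations=["loggers"])
        self.val_loss_avg.reset()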
Alternatives
Additional context
cc @Borda @tchaton @justusschock @awaelchli @carmocca @edward-io @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy