📚 Documentation
After investigating the community issue #10379, it seems the move_metrics_to_cpu parameter isn't working as expected with DDP.
Bug reason: The ResultCollection object storing the metrics is moved to the CPU here, and on epoch end, when compute is performed, the tensors are reduced across processes, which raises RuntimeError: Tensors must be CUDA and dense.
Note: this should also fail with the sync_dist argument of the self.log method.
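To make the failure mode concrete, here is a minimal sketch (not Lightning code; it assumes a process group has already been initialized with the NCCL backend, e.g. via torchrun) showing why reducing a CPU tensor fails:

```python
import torch
import torch.distributed as dist

# Assumes an NCCL process group is already initialized (e.g. via torchrun).
# This mirrors what happens on epoch end: the metric tensor was moved to the
# CPU by move_metrics_to_cpu, but NCCL can only reduce CUDA tensors.
metric = torch.tensor(1.0)  # lives on the CPU
dist.all_reduce(metric)     # RuntimeError: Tensors must be CUDA and dense
```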
Background:
Before: Prior to the ResultCollection object, logged tensor metrics had an O(num_batches_in_epoch) memory footprint, as they were stored within a list and reduced on epoch end. When the epoch had a very large number of batches, this could raise an OOM.
Now: The ResultCollection has an O(1) memory footprint for tensor metrics; therefore, the move_metrics_to_cpu argument isn't as impactful as before.
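As an illustration (a simplified sketch, not the actual Lightning implementation), the difference between the two memory behaviours looks roughly like this:

```python
import torch

batch_values = [torch.rand(1) for _ in range(1000)]  # stand-in for per-batch metric values

# Before: every batch value is kept until epoch end -> O(num_batches_in_epoch) memory
stored = []
for value in batch_values:
    stored.append(value)
epoch_metric = torch.cat(stored).mean()

# Now: a single running aggregate is kept -> O(1) memory
cumulated, count = torch.zeros(1), 0
for value in batch_values:
    cumulated += value
    count += 1
epoch_metric = cumulated / count
```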
I believe there are two options going forward:
Option 1 🚀:
Make the ResultCollection aware of move_metrics_to_cpu. The ResultCollection would be responsible for moving the ResultMetric objects back to the device before the distributed reduction and back to the CPU right after.
Pros:
- This might be impactful for large Metric objects from TorchMetrics, but a user could alternatively do this manually within their LightningModule (see the sketch after this list).
Cons:
- Engineering heavy.
- Performance drop from the extra host-device transfers.
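For the manual workaround mentioned under Pros, a rough sketch inside a LightningModule might look like this (hypothetical train_acc attribute; TorchMetrics' Metric is an nn.Module, so .to() moves its state):

```python
# Inside a LightningModule; `self.train_acc` is a hypothetical
# torchmetrics.Accuracy attribute registered in __init__.
def training_step(self, batch, batch_idx):
    x, y = batch
    preds = self(x)
    # Keep the (potentially large) metric state on the CPU between batches,
    # moving it back to the device only for the update.
    self.train_acc.to(self.device)
    self.train_acc.update(preds, y)
    self.train_acc.to("cpu")
    ...
```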
Here is the pseudo-code for such a solution. It would be spread across several parts of the code base.
```python
class ResultCollection:

    def __init__(self, ..., move_metrics_to_cpu):
        self.move_metrics_to_cpu = move_metrics_to_cpu

    def metrics(self, on_step=False):
        if on_step:
            for result_metric in self.result_metrics:
                # move the metric back to the device
                if self.move_metrics_to_cpu:
                    result_metric.to(self.device)
                ... = result_metric.compute()  # perform the distributed reduction
                # move the metric back to the CPU
                if self.move_metrics_to_cpu:
                    result_metric.to("cpu")
```

Option 2 😋:
Deprecate and remove the section of code moving the ResultCollection to the CPU.
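A minimal sketch of what the deprecation path could look like (assuming Lightning's rank_zero_deprecation utility; the exact wording and placement would need to be decided):

```python
from pytorch_lightning.utilities import rank_zero_deprecation

# Hypothetical check inside the Trainer: warn instead of moving the
# ResultCollection, then drop the flag entirely in a later release.
if trainer.move_metrics_to_cpu:
    rank_zero_deprecation(
        "`move_metrics_to_cpu=True` no longer has an effect on logged metrics"
        " and will be removed in a future release."
    )
```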