Proposed refactoring or deprecation
The Loop `run` design doesn't include a mechanism to share data between hooks during the same iteration or between iterations. This forces us to write everything to `self`, which is flexible but makes it easy to forget to clear state, since those variables will not be garbage collected.
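A minimal sketch of this pattern, using a hypothetical loop rather than one of the actual `pytorch_lightning.loops` classes:

```python
# Hypothetical loop illustrating the current pattern: per-run state lives on `self`
# and has to be cleared manually, otherwise it outlives `run`.
class MyLoop:
    def run(self, data):
        self.on_run_start(data)
        while not self.done:
            self.advance()
        return self.on_run_end()

    def on_run_start(self, data):
        self.done = False
        self.dataloader_iter = iter(data)  # temporary state written to `self`
        self.outputs = []

    def advance(self):
        try:
            self.outputs.append(next(self.dataloader_iter))
        except StopIteration:
            self.done = True

    def on_run_end(self):
        outputs, self.outputs = self.outputs, []  # free memory
        # easy to forget: `self.dataloader_iter` is never cleared here, so the
        # iterator stays referenced long after `run` has returned
        return outputs
```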
Motivation
#9386 is a perfect example: it shows that `self.dataloader_iter` was added but the reference was never freed. It gets defined in `on_run_start` but is only used in a later hook.
This pattern is also seen in other places:
```
pytorch_lightning/loops/optimizer/optimizer_loop.py:89: outputs, self.outputs = self.outputs, []  # free memory
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:136: # free memory
pytorch_lightning/loops/epoch/prediction_epoch_loop.py:104: # free memory
pytorch_lightning/loops/epoch/training_epoch_loop.py:222: # free memory
pytorch_lightning/loops/batch/manual.py:82: output, self._output = self._output, None  # free memory
pytorch_lightning/loops/batch/training_batch_loop.py:94: self.batch_outputs = None  # free memory
pytorch_lightning/loops/dataloader/evaluation_loop.py:120: # free memory
```
Pitch
Automatically dereference data at the end of `run`.
Option 1:
```python
shm: Optional[Any] = None  # shared memory that should get dereferenced after `run`
shm = self.on_run_start(*args, shm=shm, **kwargs)
while not self.done:
    try:
        shm = self.on_advance_start(*args, shm=shm, **kwargs)
        shm = self.advance(*args, shm=shm, **kwargs)
        shm = self.on_advance_end(shm=shm)
        self.restarting = False
    except StopIteration:
        break
output = self.on_run_end(shm=shm)
return output
```
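Under this design, a loop's hooks would receive and return the shared state instead of writing it to `self`. A hypothetical sketch (not an actual Lightning class; a plain dict is used as the container and the hook bodies are only illustrative):

```python
# Hypothetical hooks threading `shm` as in the Option 1 `run` body above.
class MyEvaluationLoop:
    done = False  # kept trivial; the `run` loop exits via StopIteration

    def on_run_start(self, data, *args, shm=None, **kwargs):
        # create the per-run state here and hand it back to `run`
        return {"dataloader_iter": iter(data), "outputs": []}

    def on_advance_start(self, *args, shm=None, **kwargs):
        return shm

    def advance(self, *args, shm=None, **kwargs):
        shm["outputs"].append(next(shm["dataloader_iter"]))  # may raise StopIteration
        return shm

    def on_advance_end(self, shm=None):
        return shm

    def on_run_end(self, shm=None):
        # `shm` is dropped once `run` returns, so the iterator is freed automatically
        return shm["outputs"]
```

The trade-off is that every hook now has to accept and return `shm` explicitly, which Option 2 avoids by keeping the container on `self`.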
Option 2:
```python
from abc import ABC
from types import SimpleNamespace


class Loop(ABC):
    def __init__(self):
        # a namespace (rather than a bare `object()`) so attributes can be set on it
        self.shm = SimpleNamespace()

    def run(self, *args, **kwargs):
        ...
        self.on_run_start(*args, **kwargs)
        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
                self.restarting = False
            except StopIteration:
                break
        output = self.on_run_end()
        self.shm = SimpleNamespace()  # free memory
        return output
```
Here, loop writers save the temporary state with `self.shm.dataloader_iter = ...`
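A loop written against this design could look like the following sketch, a hypothetical subclass of the `Loop` class above (attribute names are only illustrative):

```python
# Hypothetical subclass of the Option 2 `Loop` sketch above (not an actual Lightning class).
class MyEvaluationLoop(Loop):
    done = False  # kept trivial; the `run` loop exits via StopIteration

    def on_run_start(self, data, *args, **kwargs):
        # temporary state goes on `self.shm` instead of directly on `self`
        self.shm.dataloader_iter = iter(data)
        self.shm.outputs = []

    def on_advance_start(self, *args, **kwargs):
        pass

    def advance(self, *args, **kwargs):
        self.shm.outputs.append(next(self.shm.dataloader_iter))  # may raise StopIteration

    def on_advance_end(self):
        pass

    def on_run_end(self):
        # no manual `self.dataloader_iter = None` needed: `run` resets `self.shm` afterwards
        return self.shm.outputs


# e.g. MyEvaluationLoop().run(range(3)) returns [0, 1, 2] and leaves no iterator behind
```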
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @Borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @rohitgr7 @akihironitta