
[RFC] Mechanism to dereference automatically in the Loops #9385

@carmocca

Description

Proposed refactoring or deprecation

The Loop.run design

https://github.com/PyTorchLightning/pytorch-lightning/blob/41ba639859cf6c6bf319eb33e5b3394504315962/pytorch_lightning/loops/base.py#L94-L120

doesn't include a mechanism to share data between hooks within the same iteration or across iterations. This forces loop writers to store everything on self, which is flexible but makes it easy to forget to clear state, as those attributes will never be garbage collected on their own.

Motivation

#9386 is a perfect example: it shows that self.dataloader_iter was added but the reference was never freed. It gets defined in on_run_start but is only used in a later hook.
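Concretely, the leaked reference looks roughly like this (a sketch; the hook names match the Loop API, while the loop and its process method are made up for illustration):

class MyLoop(Loop):
    def on_run_start(self, dataloader):
        # temporary state written to `self` ...
        self.dataloader_iter = iter(dataloader)

    def advance(self):
        # ... and only consumed in a later hook
        batch = next(self.dataloader_iter)
        self.process(batch)

    # nothing ever resets `self.dataloader_iter`, so the iterator (and the
    # data it references) stays alive on the instance after `run` finishes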

This pattern is also seen in other places (the shared idiom is sketched after the list):

pytorch_lightning/loops/optimizer/optimizer_loop.py:89:        outputs, self.outputs = self.outputs, []  # free memory
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:136:        # free memory
pytorch_lightning/loops/epoch/prediction_epoch_loop.py:104:        # free memory
pytorch_lightning/loops/epoch/training_epoch_loop.py:222:        # free memory
pytorch_lightning/loops/batch/manual.py:82:        output, self._output = self._output, None  # free memory
pytorch_lightning/loops/batch/training_batch_loop.py:94:        self.batch_outputs = None  # free memory
pytorch_lightning/loops/dataloader/evaluation_loop.py:120:        # free memory
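All of these sites rely on the same manual idiom, shown here in minimal form (attribute name taken from the optimizer loop above):

def on_run_end(self):
    # swap the attribute for an empty value so the accumulated outputs can
    # be garbage collected while still being returned to the caller
    outputs, self.outputs = self.outputs, []
    return outputs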

Pitch

Automatically dereference shared data at the end of run.

Option 1:

shm: Optional[Any] = None  # shared memory that should get dereferenced after `run`
shm = self.on_run_start(*args, shm=shm, **kwargs)

while not self.done:
    try:
        shm = self.on_advance_start(*args, shm=shm, **kwargs)
        shm = self.advance(*args, shm=shm, **kwargs)
        shm = self.on_advance_end(shm=shm)
        self.restarting = False
    except StopIteration:
        break

output = self.on_run_end(shm=shm)
return output
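Under this option, per-iteration state travels through the hook arguments and return values instead of self. A sketch of how a loop author might use it (the dict container and names are purely illustrative):

def on_run_start(self, dataloader, shm=None):
    # create the shared state and hand it back to `run`
    return {"dataloader_iter": iter(dataloader)}

def advance(self, dataloader, shm=None):
    batch = next(shm["dataloader_iter"])
    self.process(batch)
    return shm  # pass it along so the next hook call receives it

Since run never stores shm on the instance, everything in it is dereferenced as soon as run returns.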

Option 2:

from abc import ABC
from types import SimpleNamespace


class Loop(ABC):
    def __init__(self):
        # a `SimpleNamespace` (rather than a bare `object()`, which does not
        # accept attribute assignment) so loop writers can set attributes on it
        self.shm = SimpleNamespace()

    def run(self, *args, **kwargs):
        ...
        self.on_run_start(*args, **kwargs)

        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
                self.restarting = False
            except StopIteration:
                break

        output = self.on_run_end()
        self.shm = SimpleNamespace()  # free memory by dropping all shared state
        return output

where loop writers save temporary state with self.shm.dataloader_iter = ...
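With this option, the motivating example would look roughly like this (assuming the SimpleNamespace container above):

def on_run_start(self, dataloader):
    self.shm.dataloader_iter = iter(dataloader)

def advance(self):
    batch = next(self.shm.dataloader_iter)
    self.process(batch)

No per-loop cleanup is needed: run replaces self.shm wholesale, so the iterator gets freed automatically.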



cc @Borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @rohitgr7 @akihironitta

Labels: design (Includes a design discussion), feature (Is an improvement or enhancement), let's do it! (approved to implement), loops (Related to the Loop API)
