Lazy initialize properties in Strategies #11097

@four4fish

Description

Proposed refactor

Strategy has four properties set in `__init__`: `accelerator`, `precision_plugin`, `checkpoint_io`, and `cluster_environment`.

Users can pass either a strategy instance or a strategy string to the Trainer.

  • If the user passes a string, `accelerator_connector` initializes the strategy with the selected properties.
  • If the user passes a strategy instance, the logic gets messy: we have to track what the user passed in vs. the strategy's default values vs. what the Trainer's `accelerator_connector` would choose.

This issue proposes lazy initialization of strategy properties. @awaelchli and @ananthsub have also brought this up before.

Motivation

Correctness, easier maintenance, and enabling future simplifications.

Pitch

Currently, `TrainingTypePlugin.__init__` sets default values for `precision_plugin` and `checkpoint_io`:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/training_type_plugin.py#L40-L58

    def __init__(
        self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
    ) -> None:
        self._model: Optional[Module] = None
        checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
        self._checkpoint_io = checkpoint_io
        self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()

Proposal

    def __init__(
        self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
    ) -> None:
        self._model: Optional[Module] = None
        # Store exactly what the user passed; defaults are resolved lazily in the properties.
        self._checkpoint_io = checkpoint_io
        self._precision_plugin = precision_plugin

    @property
    def checkpoint_io(self) -> CheckpointIO:
        return self._checkpoint_io if self._checkpoint_io is not None else TorchCheckpointIO()

    @checkpoint_io.setter
    def checkpoint_io(self, checkpoint_io: CheckpointIO) -> None:
        self._checkpoint_io = checkpoint_io

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: PrecisionPlugin) -> None:
        self._precision_plugin = precision_plugin

In `accelerator_connector.py`, we then know that `training_type_plugin._precision_plugin` stays `None` until the setter is called, so any non-`None` value is exactly what the user passed in, not a default.
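To illustrate, here is a minimal self-contained sketch of the proposed lazy-initialization pattern. The classes below (`CheckpointIO`, `TorchCheckpointIO`, `TrainingTypePlugin`) are simplified stand-ins for the real Lightning classes, shown only to demonstrate how the connector can distinguish "user passed nothing" from "user passed a value":

```python
from typing import Optional


class CheckpointIO:
    """Stand-in for Lightning's CheckpointIO base class."""


class TorchCheckpointIO(CheckpointIO):
    """Stand-in for the default checkpoint IO."""


class TrainingTypePlugin:
    def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
        # Store exactly what the user passed; no default is materialized here.
        self._checkpoint_io = checkpoint_io

    @property
    def checkpoint_io(self) -> CheckpointIO:
        # Resolve the default lazily, only when the property is read.
        return self._checkpoint_io if self._checkpoint_io is not None else TorchCheckpointIO()

    @checkpoint_io.setter
    def checkpoint_io(self, checkpoint_io: CheckpointIO) -> None:
        self._checkpoint_io = checkpoint_io


plugin = TrainingTypePlugin()
# The private attribute reveals the user passed nothing...
assert plugin._checkpoint_io is None
# ...while the public property still yields a usable default.
assert isinstance(plugin.checkpoint_io, TorchCheckpointIO)

# After the connector (or user) sets a value, it is kept as-is.
plugin.checkpoint_io = TorchCheckpointIO()
assert plugin._checkpoint_io is not None
```

Note that, as written, a fresh default instance is created on every property access when nothing was set; the connector is expected to call the setter once it has decided on the final value.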

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging Pytorch Lightning, Transformers, and Hydra.

cc @justusschock @awaelchli @akihironitta @kaushikb11 @ananthsub
