Description
Proposed refactor
Strategy has four properties set in __init__: accelerator, precision_plugin, checkpoint_io and cluster_environment.
Users can pass either a strategy instance or a strategy string to the Trainer.
- If the user passes a string, the accelerator_connector initializes the strategy with the selected properties.
- If the user passes a strategy instance, the logic gets messy. We have to track what the user passed in vs. the strategy's default values vs. what Trainer.accelerator_connector selects.
This issue proposes lazy initialization of the strategy properties. @awaelchli and @ananthsub have also brought this up before.
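The following minimal, self-contained sketch shows why eager defaults make the user's input ambiguous. PrecisionPlugin here is a local stub standing in for Lightning's class. EagerStrategy and LazyStrategy are hypothetical names for the current and the proposed behaviour. None of this is Lightning's actual API. With eager defaults, the private attribute is always populated, so the connector cannot tell whether the user supplied a plugin. With lazy defaults, it can.

```python
from typing import Optional


class PrecisionPlugin:
    """Local stub standing in for Lightning's default precision plugin."""


class EagerStrategy:
    """Hypothetical: mirrors the current behaviour (default filled in inside __init__)."""

    def __init__(self, precision_plugin: Optional[PrecisionPlugin] = None) -> None:
        self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()


class LazyStrategy:
    """Hypothetical: mirrors the proposal (only the user's value is stored)."""

    def __init__(self, precision_plugin: Optional[PrecisionPlugin] = None) -> None:
        self._precision_plugin = precision_plugin

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        # The default is resolved on access; the private attribute stays None.
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()


user_plugin = PrecisionPlugin()

# Eager: "user passed nothing" and "user passed a plugin" look the same.
print(EagerStrategy()._precision_plugin is not None)  # True
print(EagerStrategy(user_plugin)._precision_plugin is not None)  # True

# Lazy: the two cases can be told apart, and callers still get a usable default.
print(LazyStrategy()._precision_plugin is None)  # True
print(LazyStrategy(user_plugin)._precision_plugin is user_plugin)  # True
print(isinstance(LazyStrategy().precision_plugin, PrecisionPlugin))  # True
```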
Motivation
Improve correctness and maintainability, and enable future simplifications.
Pitch
Currently, TrainingTypePlugin.__init__ sets default values for precision_plugin and checkpoint_io:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/training_type_plugin.py#L40-L58
```python
def __init__(
    self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
) -> None:
    self._model: Optional[Module] = None
    checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
    self._checkpoint_io = checkpoint_io
    self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()
```

Proposal
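```python
def __init__(
    self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
) -> None:
    self._model: Optional[Module] = None
    # Store only what the user passed in; defaults are resolved lazily by the properties below.
    self._checkpoint_io = checkpoint_io
    self._precision_plugin = precision_plugin

@property
def checkpoint_io(self) -> CheckpointIO:
    # Until a value is set, a new default TorchCheckpointIO is returned on each access.
    return self._checkpoint_io if self._checkpoint_io is not None else TorchCheckpointIO()

@checkpoint_io.setter
def checkpoint_io(self, checkpoint_io: Optional[CheckpointIO]) -> None:
    self._checkpoint_io = checkpoint_io

@property
def precision_plugin(self) -> PrecisionPlugin:
    # Until a value is set, a new default PrecisionPlugin is returned on each access.
    return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

@precision_plugin.setter
def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
    self._precision_plugin = precision_plugin
```

In accelerator_connector.py, we then know that training_type_plugin._precision_plugin, before we call the setter, is what the user passed in rather than the default value.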
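This rough sketch shows how the connector side could use this, again with local stub classes. resolve_precision_plugin and MixedPrecisionPlugin are hypothetical names chosen for illustration, not existing Lightning code. The connector applies its own selection only when the private attribute is still None, so a plugin the user provided is never silently overwritten.

```python
from typing import Optional


class PrecisionPlugin:
    """Local stub for the default precision plugin."""


class MixedPrecisionPlugin(PrecisionPlugin):
    """Hypothetical stub for a plugin the connector might select from Trainer flags."""


class Strategy:
    """Stub strategy following the lazy-property proposal."""

    def __init__(self, precision_plugin: Optional[PrecisionPlugin] = None) -> None:
        self._precision_plugin = precision_plugin

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
        self._precision_plugin = precision_plugin


def resolve_precision_plugin(strategy: Strategy, selected: PrecisionPlugin) -> None:
    """Hypothetical connector step: apply the connector's choice only if the user left it unset."""
    if strategy._precision_plugin is None:
        # Nothing user-provided, so the connector's selection wins.
        strategy.precision_plugin = selected
    # Otherwise keep the user's plugin (a real connector might warn or raise on conflicting flags).


default_strategy = Strategy()
resolve_precision_plugin(default_strategy, MixedPrecisionPlugin())
print(type(default_strategy.precision_plugin).__name__)  # MixedPrecisionPlugin

user_plugin = PrecisionPlugin()
custom_strategy = Strategy(precision_plugin=user_plugin)
resolve_precision_plugin(custom_strategy, MixedPrecisionPlugin())
print(custom_strategy.precision_plugin is user_plugin)  # True
```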
Additional context
cc @justusschock @awaelchli @akihironitta @kaushikb11 @ananthsub