2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py

@@ -171,7 +171,7 @@ def setup(
         # Let accelerator/plugin wrap and connect the models and optimizers
         model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
         model = _LiteModule(model, self._precision_plugin)
-        optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
         self._models_setup += 1
         if optimizers:
             # join both types in a list for API convenience
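For context, a minimal usage sketch of what this touches from the user side: `setup` still returns the wrapped objects, only the optimizer wrapper it builds now receives the strategy directly instead of reaching through the accelerator. The `ToyLite` subclass and the toy model/optimizer below are illustrative assumptions, not part of this PR.

import torch

from pytorch_lightning.lite import LightningLite


class ToyLite(LightningLite):
    def run(self):
        model = torch.nn.Linear(2, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # setup() wraps these in _LiteModule / _LiteOptimizer; after this PR the
        # optimizer wrapper holds a reference to self._strategy.
        model, optimizer = self.setup(model, optimizer)
        loss = model(torch.rand(4, 2)).sum()
        self.backward(loss)
        optimizer.step()  # delegated to the strategy's optimizer_step


if __name__ == "__main__":
    ToyLite(accelerator="cpu").run()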
15 changes: 6 additions & 9 deletions pytorch_lightning/lite/wrappers.py

@@ -19,9 +19,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

-from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from pytorch_lightning.plugins import PrecisionPlugin
+from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


@@ -30,31 +29,29 @@ def _do_nothing_closure() -> None:


 class _LiteOptimizer:
-    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: TrainingTypePlugin) -> None:
         """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
-        step calls to the accelerator/strategy plugin.
+        step calls to the strategy plugin.

         The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.

         Args:
             optimizer: The optimizer to wrap
-            accelerator: Reference to the accelerator for handling the optimizer step
+            strategy: Reference to the strategy for handling the optimizer step
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer
         self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
-        self._accelerator = accelerator
-        # TODO (@awaelchli) refactor to take Strategy as param
-        self._strategy = self._accelerator.training_type_plugin
+        self._strategy = strategy

     @property
     def optimizer(self) -> Optimizer:
         return self._optimizer

     def state_dict(self) -> Dict[str, Tensor]:
-        return self._accelerator.optimizer_state(self.optimizer)
+        return self._strategy.optimizer_state(self.optimizer)

     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
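A small standalone sketch of the delegation the wrapper now implements: `state_dict()` and `step()` are forwarded to whatever strategy object was passed in. The Mock strategy is a stand-in used purely for illustration (mirroring the updated tests below), not a real training type plugin.

from unittest.mock import Mock

import torch

from pytorch_lightning.lite.wrappers import _LiteOptimizer

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
strategy = Mock()
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)

lite_optimizer.state_dict()  # forwarded to strategy.optimizer_state(optimizer)
lite_optimizer.step()        # forwarded to strategy.optimizer_step(optimizer, ...)

strategy.optimizer_state.assert_called_with(optimizer)
strategy.optimizer_step.assert_called_once()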
15 changes: 6 additions & 9 deletions tests/lite/test_wrappers.py

@@ -17,7 +17,6 @@
 import torch
 from torch.utils.data.dataloader import DataLoader

-from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -144,21 +143,19 @@ def test_lite_optimizer_wraps():


 def test_lite_optimizer_state_dict():
-    """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state."""
+    """Test that the LiteOptimizer calls into the strategy to collect the state."""
     optimizer = Mock()
-    accelerator = Mock()
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    strategy = Mock()
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.state_dict()
-    accelerator.optimizer_state.assert_called_with(optimizer)
+    strategy.optimizer_state.assert_called_with(optimizer)


 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
     strategy = Mock()
-    accelerator = Accelerator(None, strategy)
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
-    strategy = accelerator.training_type_plugin
     strategy.optimizer_step.assert_called_once()
-    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model)
+    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)
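To exercise just these updated tests locally, an invocation along these lines should work; the -k expression is a plain substring match on test names, so it also picks up test_lite_optimizer_wraps.

import pytest

# Select the _LiteOptimizer tests in tests/lite/test_wrappers.py by name.
pytest.main(["tests/lite/test_wrappers.py", "-k", "lite_optimizer"])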