
Commit abc805f

carmocca and awaelchli authored
Remove the model argument from Lite's optimizer_step via structural typing (#14810)
Co-authored-by: awaelchli <[email protected]>
1 parent dbb4482 commit abc805f

File tree

10 files changed (+60 −45 lines changed)

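The "structural typing" in the commit title refers to Python's ``typing.Protocol``: any object that provides the required methods satisfies the type, regardless of its class hierarchy. The following is a minimal, self-contained sketch of the idea, not code from this commit; the names ``HasStep``, ``MyEngine``, and ``run_step`` are illustrative only.

    from typing import Any, Protocol, runtime_checkable

    import torch


    @runtime_checkable
    class HasStep(Protocol):
        """Anything with a ``step()`` method counts; no inheritance required."""

        def step(self, *args: Any, **kwargs: Any) -> Any:
            ...


    class MyEngine:
        """Stand-in for an engine (e.g. DeepSpeed) that steps its own optimizer."""

        def step(self) -> None:
            print("engine.step()")


    def run_step(steppable: HasStep) -> None:
        # The caller no longer needs a separate `model` argument to decide who steps.
        steppable.step()


    run_step(torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1))  # a real optimizer
    run_step(MyEngine())  # an engine that merely looks like an optimizer
    assert isinstance(MyEngine(), HasStep)  # True thanks to @runtime_checkable

This is the mechanism the new ``Steppable`` protocol (added to ``types.py`` below) relies on, so the precision plugins can accept either a regular optimizer or DeepSpeed's engine without a separate ``model`` argument.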

src/lightning_lite/plugins/precision/deepspeed.py

Lines changed: 4 additions & 11 deletions
@@ -15,11 +15,11 @@

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
-from torch.optim import LBFGS, Optimizer

 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.imports import _APEX_AVAILABLE
+from lightning_lite.utilities.types import Steppable

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
@@ -65,21 +65,14 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level

-    def backward(self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+    def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
-        if model is None:
-            raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)

     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional["deepspeed.DeepSpeedEngine"] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
-        if isinstance(optimizer, LBFGS):
-            raise TypeError("DeepSpeed and the LBFGS optimizer are not compatible.")
-        if model is None:
-            raise TypeError("`optimizer_step()` requires a reference to the model.")
         # DeepSpeed handles the optimizer step internally
-        return model.step(**kwargs)
+        return optimizer.step(**kwargs)

src/lightning_lite/plugins/precision/native_amp.py

Lines changed: 4 additions & 4 deletions
@@ -17,10 +17,11 @@
 import torch
 from torch import Tensor
 from torch.nn import Module
-from torch.optim import LBFGS, Optimizer
+from torch.optim import LBFGS

 from lightning_lite.plugins.precision import Precision
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
+from lightning_lite.utilities.types import Steppable

 if _TORCH_GREATER_EQUAL_1_10:
     from torch import autocast as new_autocast
@@ -63,13 +64,12 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs

     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional[Module] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
         if self.scaler is None:
             # skip scaler logic, as bfloat16 does not require scaler
-            return super().optimizer_step(optimizer, model=model, **kwargs)
+            return super().optimizer_step(optimizer, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise TypeError("Native AMP and the LBFGS optimizer are not compatible.")
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found

src/lightning_lite/plugins/precision/precision.py

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,7 @@
 from torch.nn import Module
 from torch.optim import Optimizer

-from lightning_lite.utilities.types import _PARAMETERS
+from lightning_lite.utilities.types import _PARAMETERS, Steppable


 class Precision:
@@ -61,8 +61,7 @@ def post_backward(self, tensor: Tensor, module: Optional[Module]) -> None:

     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional[Module] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
         """Hook to run the optimizer step."""

src/lightning_lite/plugins/precision/tpu.py

Lines changed: 3 additions & 6 deletions
@@ -11,21 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
-
-from torch.nn import Module
-from torch.optim import Optimizer
+from typing import Any

 from lightning_lite.plugins.precision.precision import Precision
+from lightning_lite.utilities.types import Steppable


 class TPUPrecision(Precision):
     """Precision plugin for TPU integration."""

     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional[Module] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:

src/lightning_lite/strategies/strategy.py

Lines changed: 3 additions & 5 deletions
@@ -30,7 +30,7 @@
 from lightning_lite.utilities.apply_func import move_data_to_device
 from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.optimizer import optimizer_to_device
-from lightning_lite.utilities.types import _PATH
+from lightning_lite.utilities.types import _PATH, Steppable

 TBroadcast = TypeVar("TBroadcast")
 TReduce = TypeVar("TReduce")
@@ -167,18 +167,16 @@ def backward(self, tensor: Tensor, module: Optional[Module], *args: Any, **kwarg

     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional[Module] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
         """Performs the actual optimizer step.

         Args:
             optimizer: the optimizer performing the step
-            model: reference to the model, optionally defining optimizer step related hooks
             **kwargs: Any extra arguments to ``optimizer.step``
         """
-        return self.precision_plugin.optimizer_step(optimizer, model=model, **kwargs)
+        return self.precision_plugin.optimizer_step(optimizer, **kwargs)

     @abstractmethod
     def reduce(

src/lightning_lite/utilities/types.py

Lines changed: 9 additions & 0 deletions
@@ -77,3 +77,12 @@ def __init__(

     def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None:
         ...
+
+
+@runtime_checkable
+class Steppable(Protocol):
+    """To structurally type ``optimizer.step()``"""
+
+    # Inferred from `torch.optim.optimizer.pyi`
+    def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
+        ...
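A quick sanity check of what the plugins now rely on, as a hedged sketch assuming the ``Steppable`` import path added above. Because the protocol is ``@runtime_checkable``, ``isinstance`` only verifies that a ``step`` attribute exists; it does not validate the signature.

    import torch

    from lightning_lite.utilities.types import Steppable

    # A plain torch optimizer structurally satisfies the protocol.
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    assert isinstance(opt, Steppable)


    class HasStepOnly:
        def step(self) -> None:  # different signature than the protocol declares
            pass


    # runtime_checkable protocols check only the *presence* of `step`, not its signature.
    assert isinstance(HasStepOnly(), Steppable)
    assert not isinstance(object(), Steppable)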

src/lightning_lite/wrappers.py

Lines changed: 7 additions & 2 deletions
@@ -25,6 +25,7 @@
 from lightning_lite.strategies import Strategy
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning_lite.utilities.types import Steppable

 T_destination = TypeVar("T_destination", bound=Dict[str, Any])

@@ -56,9 +57,13 @@ def state_dict(self) -> Dict[str, Tensor]:

     def step(self, closure: Optional[Callable] = None) -> Any:
         kwargs = dict(closure=closure) if closure is not None else {}
+        if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Steppable):
+            # only DeepSpeed defines this
+            optimizer = self._strategy.model
+        else:
+            optimizer = self.optimizer
         return self._strategy.optimizer_step(
-            self.optimizer,
-            model=getattr(self._strategy, "model", None),
+            optimizer,
             **kwargs,
         )
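With these changes, ``_LiteOptimizer.step()`` decides which object actually performs the step: the strategy's ``model`` when that model is itself steppable (the DeepSpeed engine), otherwise the wrapped optimizer. A simplified sketch of that dispatch follows; ``FakeEngine``, ``PlainOptimizer``, and ``pick_steppable`` are illustrative stand-ins, not Lightning classes.

    from typing import Any, Optional

    from lightning_lite.utilities.types import Steppable


    class FakeEngine:
        """Stand-in for a DeepSpeed-style engine that performs the step itself."""

        def step(self, closure: Optional[Any] = None) -> None:
            self.stepped = True


    class PlainOptimizer:
        """Stand-in for a regular optimizer wrapped by _LiteOptimizer."""

        def step(self, closure: Optional[Any] = None) -> None:
            self.stepped = True


    def pick_steppable(strategy_model: Optional[object], optimizer: Any) -> Any:
        # Mirrors the branch added to _LiteOptimizer.step() above: if the strategy
        # exposes a steppable `model` (only DeepSpeed does), step that instead.
        if isinstance(strategy_model, Steppable):
            return strategy_model
        return optimizer


    assert isinstance(pick_steppable(None, PlainOptimizer()), PlainOptimizer)
    assert isinstance(pick_steppable(FakeEngine(), PlainOptimizer()), FakeEngine)

The test changes in ``test_wrappers.py`` below exercise these two branches.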

tests/tests_lite/plugins/precision/test_deepspeed.py

Lines changed: 16 additions & 4 deletions
@@ -15,8 +15,10 @@
 from unittest.mock import Mock

 import pytest
+from tests_lite.helpers.runif import RunIf

 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
+from lightning_lite.utilities.types import Steppable


 def test_invalid_precision_with_deepspeed_precision():
@@ -47,10 +49,20 @@ def test_deepspeed_precision_backward():
     model.backward.assert_called_once_with(tensor, "positional-arg", keyword="arg")


+@RunIf(deepspeed=True)
+def test_deepspeed_engine_is_steppable():
+    """Test that the ``DeepSpeedEngine`` conforms to the Steppable API.
+
+    If this fails, then optimization will be broken for DeepSpeed.
+    """
+    from deepspeed import DeepSpeedEngine
+
+    engine = DeepSpeedEngine(Mock(), Mock())
+    assert isinstance(engine, Steppable)
+
+
 def test_deepspeed_precision_optimizer_step():
     precision_plugin = DeepSpeedPrecision(precision=32, amp_type="native")
-    optimizer = Mock()
-    model = Mock()
-    precision_plugin.optimizer_step(optimizer, model=model, lr_kwargs=dict())
+    optimizer = model = Mock()
+    precision_plugin.optimizer_step(optimizer, lr_kwargs=dict())
     model.step.assert_called_once_with(lr_kwargs=dict())
-    optimizer.step.assert_not_called()

tests/tests_lite/plugins/precision/test_native_amp.py

Lines changed: 2 additions & 4 deletions
@@ -64,9 +64,8 @@ def test_native_amp_precision_optimizer_step_with_scaler():
     precision_plugin = NativeMixedPrecision(precision="mixed", device="cuda")
     precision_plugin.scaler = Mock()
     optimizer = Mock()
-    model = Mock()

-    precision_plugin.optimizer_step(optimizer, model=model, keyword="arg")
+    precision_plugin.optimizer_step(optimizer, keyword="arg")
     precision_plugin.scaler.step.assert_called_once_with(optimizer, keyword="arg")
     precision_plugin.scaler.update.assert_called_once()

@@ -76,7 +75,6 @@ def test_native_amp_precision_optimizer_step_without_scaler():
     precision_plugin = NativeMixedPrecision(precision="bf16", device="cuda")
     assert precision_plugin.scaler is None
     optimizer = Mock()
-    model = Mock()

-    precision_plugin.optimizer_step(optimizer, model=model, keyword="arg")
+    precision_plugin.optimizer_step(optimizer, keyword="arg")
     optimizer.step.assert_called_once_with(keyword="arg")

tests/tests_lite/test_wrappers.py

Lines changed: 10 additions & 6 deletions
@@ -251,18 +251,22 @@ def test_lite_optimizer_state_dict():
 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
-    strategy = Mock()
+    strategy = Mock(spec=["optimizer_step"])
     strategy.optimizer_step.return_value = 123
     lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     step_output = lite_optimizer.step()
     assert step_output == 123
-    strategy.optimizer_step.assert_called_once()
-    strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model)
+    strategy.optimizer_step.assert_called_once_with(optimizer)

-    strategy.optimizer_step.reset_mock()
+    strategy.reset_mock()

     # with closure as input
     closure = Mock()
     lite_optimizer.step(closure=closure)
-    strategy.optimizer_step.assert_called_once()
-    strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model, closure=closure)
+    strategy.optimizer_step.assert_called_once_with(optimizer, closure=closure)
+
+    # with model as optimizer
+    strategy = Mock(spec=["optimizer_step", "model"])
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
+    lite_optimizer.step()
+    strategy.optimizer_step.assert_called_once_with(strategy.model)

0 commit comments
