@@ -15,11 +15,11 @@
 
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
-from torch.optim import LBFGS, Optimizer
 
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.imports import _APEX_AVAILABLE
+from lightning_lite.utilities.types import Steppable
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
@@ -65,21 +65,14 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level
 
-    def backward(self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+    def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
-        if model is None:
-            raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)
 
     def optimizer_step(
         self,
-        optimizer: Optimizer,
-        model: Optional["deepspeed.DeepSpeedEngine"] = None,
+        optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
-        if isinstance(optimizer, LBFGS):
-            raise TypeError("DeepSpeed and the LBFGS optimizer are not compatible.")
-        if model is None:
-            raise TypeError("`optimizer_step()` requires a reference to the model.")
         # DeepSpeed handles the optimizer step internally
-        return model.step(**kwargs)
+        return optimizer.step(**kwargs)
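For context, a minimal usage sketch of the reworked hooks (not part of the diff; the model, config values, and launcher setup below are illustrative assumptions). Because `deepspeed.DeepSpeedEngine` itself exposes `backward()` and `step()`, the engine satisfies the new `Steppable` parameter and can be passed to both hooks directly:

import torch
import deepspeed
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision

# Placeholder model and DeepSpeed config; assumes the script is started under
# a distributed launcher (e.g. the `deepspeed` CLI) so initialize() can set
# up the process group.
model = torch.nn.Linear(32, 2)
config = {"train_batch_size": 8, "optimizer": {"type": "SGD", "params": {"lr": 0.1}}}
engine, _, _, _ = deepspeed.initialize(model=model, config=config)

precision = DeepSpeedPrecision(precision=32, amp_type="native")
batch = torch.randn(8, 32)
loss = engine(batch).sum()        # forward pass through the DeepSpeed engine
precision.backward(loss, engine)  # delegates to engine.backward(loss)
precision.optimizer_step(engine)  # engine.step() runs the optimizer step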