# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
-from typing import Callable, Optional
+from typing import Any, Callable, Generator, List, Optional
from weakref import proxy

from torch.optim import Optimizer

+import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def do_nothing_closure():
+def do_nothing_closure() -> None:
    return


@@ -44,93 +45,86 @@ def __init__(self, optimizer: Optimizer):
        self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

        self._optimizer = optimizer
-        self._trainer = None
-        self._optimizer_idx = None
+        self._trainer: Optional["pl.Trainer"] = None
+        self._optimizer_idx = 0

    @property
-    def optimizer(self):
+    def optimizer(self) -> Optimizer:
        return self._optimizer

    @property
-    def defaults(self):
+    def defaults(self) -> dict:
        return self._optimizer.defaults

    @defaults.setter
-    def defaults(self, defaults):
+    def defaults(self, defaults: dict) -> None:
        self._optimizer.defaults = defaults

    @property
-    def state(self):
+    def state(self) -> dict:
        return self._optimizer.state

    @state.setter
-    def state(self, state):
+    def state(self, state: dict) -> None:
        self._optimizer.state = state

    @property
-    def param_groups(self):
+    def param_groups(self) -> List[dict]:
        return self._optimizer.param_groups

    @param_groups.setter
-    def param_groups(self, param_groups):
+    def param_groups(self, param_groups: List[dict]) -> None:
        self._optimizer.param_groups = param_groups

-    def _on_trainer_init(self, trainer):
+    def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
        self._trainer = proxy(trainer)
        for opt_idx, opt in enumerate(trainer.optimizers):
            if opt == self._optimizer:
                self._optimizer_idx = opt_idx
                break

    @classmethod
-    def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
+    def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", opt_idx: int) -> "LightningOptimizer":
        # apex overrides the .step function and needs to be wrapped on each step
-        if trainer.amp_backend == AMPType.APEX:
-            optimizer = cls(optimizer)
-            optimizer._on_trainer_init(trainer)
+        if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX:
+            lightning_optimizer = cls(optimizer)
+            lightning_optimizer._on_trainer_init(trainer)
        else:
-            optimizer = trainer.lightning_optimizers[opt_idx]
-        return optimizer
+            lightning_optimizer = trainer.lightning_optimizers[opt_idx]
+        return lightning_optimizer

    @contextmanager
-    def toggle_model(self, sync_grad: bool = True):
+    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
9798 """This function is just a helper for advanced users.
9899
99100 Considering the current optimizer as A and all other optimizers as B.
100101 Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False.
101102
102-
103103 When performing gradient accumulation, there is no need to perform grad synchronization
104104 during the accumulation phase.
105105 Setting `sync_grad` to False will block this synchronization and improve performance.
106106 """
        # local import here to avoid circular import
        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

+        assert self._trainer is not None
        lightning_module = self._trainer.lightning_module

        with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
            lightning_module.toggle_optimizer(self, self._optimizer_idx)
            yield
        lightning_module.untoggle_optimizer(self._optimizer_idx)

-    def step(self, closure: Optional[Callable] = None, **kwargs):
-        """Call this directly from your training_step when doing optimizations manually. By using this we can
-        ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you.
-
-        .. note:: In Manual Optimization, the user is expected to know when to call zero_grad,
-            perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators
+    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
+        """Performs a single optimization step (parameter update).

        Args:
-
-            closure: One could provide its own optimizer_closure. Set to None by default.
-
-            kwargs: Any parameters provided to wrapped optimizer.step()
+            closure: An optional optimizer_closure.
+            kwargs: Any additional arguments to the ``optimizer.step()`` call.

        Example::

-            # Scenario for a GAN.
-
+            # Scenario for a GAN using manual optimization
            def training_step(...):
                opt_gen, opt_dis = self.optimizers()

@@ -152,8 +146,7 @@ def training_step(...):
                opt_dis.step()


-            # Scenario for a GAN advanced
-
+            # A more advanced example
            def training_step(self, batch, batch_idx, ...):
                opt_gen, opt_dis = self.optimizers()

@@ -189,10 +182,11 @@ def closure_dis():
            profiler_action += f"_{self._optimizer_idx}"

        trainer = self._trainer
+        assert trainer is not None
        with trainer.profiler.profile(profiler_action):
            trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

-    def __repr__(self):
+    def __repr__(self) -> str:
        groups = [
            {k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"}
            for group in self.param_groups
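
For context on how the `toggle_model` and `step` methods touched by this diff are typically exercised, here is a minimal sketch (not part of the diff above) of a LightningModule using manual optimization with two optimizers. The class, layers, and loss expressions (SimpleGAN, the WGAN-style losses) are illustrative assumptions, not code from this PR; only `self.optimizers()`, `manual_backward`, `toggle_model`, `step`, and `zero_grad` come from the Lightning API shown or referenced above.

import torch
from torch import nn
import pytorch_lightning as pl


class SimpleGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(8, 8)
        self.discriminator = nn.Linear(8, 1)
        self.automatic_optimization = False  # opt into manual optimization

    def configure_optimizers(self):
        opt_gen = torch.optim.SGD(self.generator.parameters(), lr=0.01)
        opt_dis = torch.optim.SGD(self.discriminator.parameters(), lr=0.01)
        return opt_gen, opt_dis

    def training_step(self, batch, batch_idx):
        # self.optimizers() returns LightningOptimizer wrappers around the raw optimizers
        opt_gen, opt_dis = self.optimizers()

        # Inside this block only the generator's parameters keep requires_grad=True;
        # passing sync_grad=False would additionally skip gradient synchronization
        # while accumulating gradients.
        with opt_gen.toggle_model():
            fake = self.generator(batch)
            loss_gen = -self.discriminator(fake).mean()
            self.manual_backward(loss_gen)
            opt_gen.step()
            opt_gen.zero_grad()

        with opt_dis.toggle_model():
            fake = self.generator(batch).detach()
            loss_dis = self.discriminator(fake).mean() - self.discriminator(batch).mean()
            self.manual_backward(loss_dis)
            opt_dis.step()
            opt_dis.zero_grad()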