 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
 
 import torch
-from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
 
 if _APEX_AVAILABLE:
     from apex import amp
 
+if TYPE_CHECKING:
+    from torch.optim import Optimizer
+
 
 class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
-    def __init__(self, amp_level: str):
+    def __init__(self, amp_level: str) -> None:
         self.backend = AMPType.APEX
         self.amp_level = amp_level
 
-    def master_params(self, optimizer: torch.optim.Optimizer):
+    def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         return amp.master_params(optimizer)
 
-    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+    def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
+                lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
         """Connects the precision plugin to the training process,
         configures apex and reinits the schedulers
         """
         if model.device.type != "cuda":
             return model, optimizers, lr_schedulers
-        model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
+        model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
         self.reinit_scheduler_properties(optimizers, lr_schedulers)
         return model, optimizers, lr_schedulers
 
     def backward(
         self,
         model: LightningModule,
         closure_loss: torch.Tensor,
-        optimizer: torch.optim.Optimizer,
+        optimizer: 'Optimizer',
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         """performs the actual backpropagation
 
         Args:
@@ -94,11 +97,11 @@ def backward(
 
     def configure_apex(
         self,
-        amp: object,
+        amp: Type,
         model: LightningModule,
-        optimizers: List[Optimizer],
+        optimizers: List['Optimizer'],
         amp_level: str,
-    ) -> Tuple[LightningModule, List[Optimizer]]:
+    ) -> Tuple[LightningModule, List['Optimizer']]:
         r"""
         Override to init AMP your own way.
         Must return a model and list of optimizers.
@@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
         return model, optimizers
 
     @staticmethod
-    def reinit_scheduler_properties(optimizers: list, schedulers: list):
+    def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
         """Reinitializes schedulers with correct properties"""
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
                     break
 
     def pre_optimizer_step(
-        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
+        self,
+        pl_module: LightningModule,
+        optimizer: 'Optimizer',
+        optimizer_idx: int,
+        lambda_closure: Callable,
+        **kwargs: Any,
     ) -> bool:
         """
         always called before the optimizer step.
@@ -160,6 +168,6 @@ def pre_optimizer_step(
         if not pl_module.automatic_optimization:
             pl_module.trainer.call_hook("on_after_backward")
 
-        optimizer.step()
+        optimizer.step(**kwargs)
 
         return False
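
The recurring pattern in this change is importing Optimizer only under typing.TYPE_CHECKING and referring to it through quoted (forward-reference) annotations, which keeps the annotation visible to static type checkers without importing torch.optim when the module is loaded. A minimal sketch of that pattern outside the plugin (the zero_all helper below is hypothetical, not part of this commit):

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    # Only evaluated by static type checkers (e.g. mypy); never executed at runtime,
    # so the module avoids the torch.optim import and any import-cycle risk.
    from torch.optim import Optimizer


def zero_all(optimizers: Sequence['Optimizer']) -> None:
    # The quoted 'Optimizer' is a forward reference resolved by the type checker only.
    for opt in optimizers:
        opt.zero_grad()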