1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from contextlib import contextmanager
15- from typing import Dict , Generator , List , Tuple , Union
15+ from typing import Dict , Generator , List , Tuple
1616
1717from torch import Tensor
1818from torch .nn import Module
1919from torch .optim import Optimizer
2020
2121import pytorch_lightning as pl
2222from pytorch_lightning .core .optimizer import LightningOptimizer
23- from pytorch_lightning .overrides .base import _LightningModuleWrapperBase
23+ from pytorch_lightning .overrides .base import _LightningModuleWrapperBase , _LightningPrecisionModuleWrapperBase
2424from pytorch_lightning .strategies .ddp import DDPStrategy
2525from pytorch_lightning .trainer .states import TrainerFn
2626from pytorch_lightning .utilities .enums import PrecisionType
@@ -51,10 +51,11 @@ def connect(self, model: "pl.LightningModule") -> None:
5151
5252 def setup (self , trainer : "pl.Trainer" ) -> None :
5353 # share ddp pids to all processes
54- self ._rank_0_will_call_children_scripts = self .broadcast (self ._rank_0_will_call_children_scripts )
54+ self ._rank_0_will_call_children_scripts : bool = self .broadcast (self ._rank_0_will_call_children_scripts )
5555 if self ._should_run_deadlock_detection ():
5656 self ._share_information_to_prevent_deadlock ()
5757
58+ assert self .accelerator is not None
5859 self .accelerator .setup (trainer )
5960
6061 # move the model to the correct device
@@ -64,6 +65,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
6465 trainer_fn = trainer .state .fn
6566 if trainer_fn == TrainerFn .FITTING :
6667 if self ._layer_sync :
68+ assert self .model is not None
6769 self .model = self ._layer_sync .apply (self .model )
6870
6971 self .setup_precision_plugin ()
@@ -73,7 +75,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
7375
7476 def configure_ddp (self ) -> None :
7577 self ._set_ddp_kwargs ()
76- self .setup_optimizers (self .model .trainer )
78+ assert self .lightning_module is not None
79+ self .setup_optimizers (self .lightning_module .trainer )
80+ assert isinstance (self .model , (pl .LightningModule , _LightningPrecisionModuleWrapperBase ))
7781 self .model , self .optimizers = self ._setup_model_and_optimizers (
7882 model = _LightningModuleWrapperBase (self .model ),
7983 optimizers = self .optimizers ,
@@ -97,12 +101,13 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
97101 return model , optimizers
98102
99103 def _wrap_optimizers (self , optimizers : List [Optimizer ]) -> List ["OSS" ]:
100- if self .model is not None and self .model .trainer .state .fn != TrainerFn .FITTING :
104+ assert self .lightning_module is not None
105+ if self .model is not None and self .lightning_module .trainer .state .fn != TrainerFn .FITTING :
101106 return optimizers
102107
103108 return self ._reinit_optimizers_with_oss (optimizers )
104109
105- def _reinit_optimizers_with_oss (self , optimizers : List [Union [ Optimizer , LightningOptimizer ] ]) -> List ["OSS" ]:
110+ def _reinit_optimizers_with_oss (self , optimizers : List [Optimizer ]) -> List ["OSS" ]:
106111 for x , optimizer in enumerate (optimizers ):
107112 if isinstance (optimizer , LightningOptimizer ):
108113 optimizer = optimizer ._optimizer
@@ -135,7 +140,7 @@ def block_backward_sync(self) -> Generator:
135140 else :
136141 yield None
137142
138- def post_training_step (self ):
143+ def post_training_step (self ) -> None :
139144 pass
140145
141146 @classmethod
0 commit comments