@@ -30,9 +30,9 @@
 from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides import _LightningPrecisionModuleWrapperBase, LightningDistributedModule
+from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -305,7 +305,7 @@ def optimizer_step(
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
-        assert isinstance(self.model, (LightningModule, _LightningPrecisionModuleWrapperBase))
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
@@ -331,7 +331,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
-        if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+        if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
             assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
@@ -384,7 +384,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self.model.predict_step(*args, **kwargs)
 
     def post_training_step(self) -> None:
-        if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+        if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
             assert self.model is not None
             self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 