Commit 3dc2fec

fix ci error
1 parent 8250e11 commit 3dc2fec

File tree

  • src/pytorch_lightning/strategies

1 file changed (+5, -5 lines)

src/pytorch_lightning/strategies/ddp.py

Lines changed: 5 additions & 5 deletions
@@ -30,9 +30,9 @@
 from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides import _LightningPrecisionModuleWrapperBase, LightningDistributedModule
+from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -305,7 +305,7 @@ def optimizer_step(
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
-        assert isinstance(self.model, (LightningModule, _LightningPrecisionModuleWrapperBase))
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
@@ -331,7 +331,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
-        if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+        if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
             assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
@@ -384,7 +384,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         return self.model.predict_step(*args, **kwargs)
 
     def post_training_step(self) -> None:
-        if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+        if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
             assert self.model is not None
             self.model.require_backward_grad_sync = True  # type: ignore[assignment]
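
Every hunk makes the same substitution: ddp.py stops importing LightningModule directly from pytorch_lightning.core.module and instead reaches it as pl.LightningModule through the top-level package it already imports (the _LightningPrecisionModuleWrapperBase import is likewise taken from pytorch_lightning.overrides.base). The commit message does not say what the CI failure was, so the standalone sketch below only illustrates the reference pattern itself; the comment about when the name gets resolved is an assumption, not something stated in the commit, and sketch.py is a hypothetical file, not part of the change.

# sketch.py -- standalone illustration of the package-level reference pattern
# used in this commit (not part of the commit itself).
import pytorch_lightning as pl  # ddp.py keeps this import and drops the direct class import


def is_lightning_module(obj: object) -> bool:
    # pl.LightningModule is looked up on the package when this function runs,
    # whereas `from pytorch_lightning.core.module import LightningModule`
    # binds the class at the moment this file is imported.
    return isinstance(obj, pl.LightningModule)


if __name__ == "__main__":
    print(is_lightning_module(object()))             # False: a plain object
    print(is_lightning_module(pl.LightningModule())) # True: instance of the base class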

0 commit comments
