Skip to content

Commit c5b9ac7

Browse files
committed
fix assertion
1 parent bf23171 commit c5b9ac7

File tree

1 file changed

+4
-2
lines changed
  • src/pytorch_lightning/strategies

1 file changed

+4
-2
lines changed

src/pytorch_lightning/strategies/ddp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
331331

332332
def pre_backward(self, closure_loss: Tensor) -> None:
333333
"""Run before precision plugin executes backward."""
334-
if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
334+
assert self.lightning_module is not None
335+
if not self.lightning_module.automatic_optimization:
335336
assert isinstance(self.model, DistributedDataParallel)
336337
prepare_for_backward(self.model, closure_loss)
337338

@@ -384,7 +385,8 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
384385
return self.model.predict_step(*args, **kwargs)
385386

386387
def post_training_step(self) -> None:
387-
if isinstance(self.lightning_module, pl.LightningModule) and not self.lightning_module.automatic_optimization:
388+
assert self.lightning_module is not None
389+
if not self.lightning_module.automatic_optimization:
388390
assert self.model is not None
389391
self.model.require_backward_grad_sync = True # type: ignore[assignment]
390392

0 commit comments

Comments
 (0)