Skip to content

Commit 75559fe

Browse files
committed
fix ci error
1 parent c5b9ac7 commit 75559fe

File tree

1 file changed

+6
-4
lines changed
  • src/pytorch_lightning/strategies

1 file changed

+6
-4
lines changed

src/pytorch_lightning/strategies/ddp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from pytorch_lightning.utilities.optimizer import optimizers_to_device
6060
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
6161
from pytorch_lightning.utilities.seed import reset_seed
62-
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
62+
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
6363

6464
if _FAIRSCALE_AVAILABLE:
6565
from fairscale.optim import OSS
@@ -333,7 +333,7 @@ def pre_backward(self, closure_loss: Tensor) -> None:
333333
"""Run before precision plugin executes backward."""
334334
assert self.lightning_module is not None
335335
if not self.lightning_module.automatic_optimization:
336-
assert isinstance(self.model, DistributedDataParallel)
336+
assert self.model is not None
337337
prepare_for_backward(self.model, closure_loss)
338338

339339
def model_to_device(self) -> None:
@@ -360,18 +360,20 @@ def reduce(
360360
return tensor
361361

362362
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
363+
assert self.model is not None
363364
with self.precision_plugin.train_step_context():
364-
assert isinstance(self.model, TrainingStep)
365365
return self.model(*args, **kwargs)
366366

367367
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
368368
with self.precision_plugin.val_step_context():
369-
assert isinstance(self.model, ValidationStep)
369+
assert self.lightning_module is not None
370+
assert self.model is not None
370371
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
371372
# used when calling `trainer.fit`
372373
return self.model(*args, **kwargs)
373374
else:
374375
# used when calling `trainer.validate`
376+
assert isinstance(self.model, ValidationStep)
375377
return self.model.validation_step(*args, **kwargs)
376378

377379
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

0 commit comments

Comments
 (0)