Skip to content

Commit bf23171

Browse files
committed
fix wrong assert in predict_step
1 parent 3dc2fec commit bf23171

File tree

1 file changed

+2
-2
lines changed
  • src/pytorch_lightning/strategies

1 file changed

+2
-2
lines changed

src/pytorch_lightning/strategies/ddp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from pytorch_lightning.utilities.optimizer import optimizers_to_device
6060
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
6161
from pytorch_lightning.utilities.seed import reset_seed
62-
from pytorch_lightning.utilities.types import STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
62+
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
6363

6464
if _FAIRSCALE_AVAILABLE:
6565
from fairscale.optim import OSS
@@ -380,7 +380,7 @@ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
380380

381381
def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
382382
with self.precision_plugin.predict_step_context():
383-
assert isinstance(self.model, TestStep)
383+
assert isinstance(self.model, PredictStep)
384384
return self.model.predict_step(*args, **kwargs)
385385

386386
def post_training_step(self) -> None:

0 commit comments

Comments
 (0)