Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,11 +991,11 @@ def _run(
if model._compiler_ctx is not None:
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies):
supported_strategy_names = " ".join(s.__name__ for s in supported_strategies)
supported_strategy_names = ", ".join(s.__name__ for s in supported_strategies)
raise RuntimeError(
"Using a compiled model is incompatible with the current strategy: "
f"{self.strategy.__class__.__name__}. "
f"Only {supported_strategy_names} support compilation."
f"Only {supported_strategy_names} support compilation. "
"Either switch to one of the supported strategies or avoid passing in "
"a compiled model."
)
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.optim import Adam, SGD

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_trainer_reference_recursively():
@RunIf(min_torch="1.14.0.dev20221202")
def test_compile_uncompile():

lit_model = BoringModel()
lit_model = DemoModel()
model_compiled = torch.compile(lit_model)

lit_model_compiled = LightningModule.from_compiled(model_compiled)
Expand Down
10 changes: 7 additions & 3 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.demos.boring_classes import (
BoringDataModule,
BoringModel,
DemoModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
Expand Down Expand Up @@ -2243,24 +2245,26 @@ def on_fit_start(self):
# TODO: replace with 1.14 when it is released
@RunIf(min_torch="1.14.0.dev20221202")
def test_trainer_compiled_model():
model = BoringModel()
model = DemoModel()

model = torch.compile(model)

data = BoringDataModule()

trainer = Trainer(
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
)
trainer.fit(model)
trainer.fit(model, data)

assert trainer.model._compiler_ctx["compiler"] == "dynamo"

model = model.to_uncompiled()

assert model._compiler_ctx is None

trainer.train(model)
trainer.fit(model)

assert trainer.model._compiler_ctx is None

Expand Down