diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py
index 0eb4b68651aa8..1d4d6b7b421c0 100644
--- a/pytorch_lightning/strategies/ddp_spawn.py
+++ b/pytorch_lightning/strategies/ddp_spawn.py
@@ -109,6 +109,10 @@ def _is_single_process_single_device(self):
     def _configure_launcher(self):
         self._launcher = _SpawnLauncher(self)
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         super().setup(trainer)
diff --git a/pytorch_lightning/strategies/dp.py b/pytorch_lightning/strategies/dp.py
index 484f7b474b02f..1d017eff933ff 100644
--- a/pytorch_lightning/strategies/dp.py
+++ b/pytorch_lightning/strategies/dp.py
@@ -63,6 +63,10 @@ def node_rank(self) -> int:
     def world_size(self) -> int:
         return 1
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py
index d97797f92daa2..c971b1848ec0f 100644
--- a/pytorch_lightning/strategies/tpu_spawn.py
+++ b/pytorch_lightning/strategies/tpu_spawn.py
@@ -87,6 +87,10 @@ def world_size(self) -> int:
     def root_device(self) -> torch.device:
         return xm.xla_device()
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     @staticmethod
     def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
         if not isinstance(dataloaders, list):
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 20c5f485b4e71..ab69a9ee6710a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -593,6 +593,21 @@ def _init_strategy(self) -> None:
         else:
             raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
 
+        from pytorch_lightning.utilities import _IS_INTERACTIVE
+
+        is_interactive_compatible = (
+            self.strategy.is_interactive_compatible if hasattr(self.strategy, "is_interactive_compatible") else False
+        )
+        if _IS_INTERACTIVE and not is_interactive_compatible:
+            raise MisconfigurationException(
+                f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
+                f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
+                " environment. Run your code as a script, or choose one of the compatible backends, for example:"
+                " dp, ddp_spawn, ddp_sharded_spawn or tpu_spawn."
+                " In case you are spawning processes yourself, make sure to include the Trainer"
+                " creation inside the worker function."
+            )
+
     def _check_and_init_precision(self) -> PrecisionPlugin:
         self._validate_precision_choice()
         if isinstance(self._precision_plugin_flag, PrecisionPlugin):
@@ -713,25 +728,6 @@ def _lazy_init_strategy(self) -> None:
         self.strategy.set_world_ranks()
         self.strategy._configure_launcher()
 
-        from pytorch_lightning.utilities import _IS_INTERACTIVE
-
-        # TODO move is_compatible logic to strategy API
-        interactive_compatible_strategy = (
-            DataParallelStrategy.strategy_name,
-            DDPSpawnStrategy.strategy_name,
-            DDPSpawnShardedStrategy.strategy_name,
-            TPUSpawnStrategy.strategy_name,
-        )
-        if _IS_INTERACTIVE and self.strategy.strategy_name not in interactive_compatible_strategy:
-            raise MisconfigurationException(
-                f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
-                f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
-                " environment. Run your code as a script, or choose one of the compatible backends:"
-                f" {', '.join(interactive_compatible_strategy)}."
-                " In case you are spawning processes yourself, make sure to include the Trainer"
-                " creation inside the worker function."
-            )
-
         # TODO: should be moved to _check_strategy_and_fallback().
         # Current test check precision first, so keep this check here to meet error order
         if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 103fc87ecde1b..c12300fb32115 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -124,20 +124,6 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
     DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
     DDP_FULLY_SHARDED = "ddp_fully_sharded"
 
-    @staticmethod
-    def interactive_compatible_types() -> list[DistributedType]:
-        """Returns a list containing interactive compatible DistributeTypes."""
-        return [
-            DistributedType.DP,
-            DistributedType.DDP_SPAWN,
-            DistributedType.DDP_SHARDED_SPAWN,
-            DistributedType.TPU_SPAWN,
-        ]
-
-    def is_interactive_compatible(self) -> bool:
-        """Returns whether self is interactive compatible."""
-        return self in DistributedType.interactive_compatible_types()
-
     def deprecate(self) -> None:
         rank_zero_deprecation(
             "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8."
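
For illustration only, a minimal self-contained sketch of the behaviour this refactor introduces: the connector now asks the strategy itself whether it is interactive-compatible and falls back to False when the property is absent, instead of consulting a hard-coded list of strategy names. The _DemoStrategy classes and _check_interactive helper below are hypothetical stand-ins, not Lightning code.

# Hypothetical stand-ins -- not the real pytorch_lightning classes.
class _DemoStrategy:
    """A strategy that does not declare interactive support."""

    strategy_name = "demo_ddp"


class _DemoSpawnStrategy(_DemoStrategy):
    """A spawn-style strategy that opts in via the new property."""

    strategy_name = "demo_ddp_spawn"

    @property
    def is_interactive_compatible(self) -> bool:
        return True


def _check_interactive(strategy, is_interactive: bool) -> None:
    # Mirrors the hasattr fallback added in accelerator_connector.py:
    # strategies without the property are treated as incompatible.
    compatible = getattr(strategy, "is_interactive_compatible", False)
    if is_interactive and not compatible:
        raise RuntimeError(
            f"`Trainer(strategy={strategy.strategy_name!r})` is not compatible with an"
            " interactive environment. Run your code as a script or pick a spawn-based strategy."
        )


if __name__ == "__main__":
    _check_interactive(_DemoSpawnStrategy(), is_interactive=True)  # passes silently
    try:
        _check_interactive(_DemoStrategy(), is_interactive=True)  # raises
    except RuntimeError as err:
        print(err)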