diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f95ebaebac6c..933f570cf541b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -114,6 +114,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015)) +- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022)) + ### Deprecated diff --git a/docs/source/extensions/accelerators.rst b/docs/source/extensions/accelerators.rst index 4db625adea5dc..753749d8a3730 100644 --- a/docs/source/extensions/accelerators.rst +++ b/docs/source/extensions/accelerators.rst @@ -25,11 +25,10 @@ One to handle differences from the training routine and one to handle different from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin - accelerator = GPUAccelerator( - precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"), - training_type_plugin=DDPPlugin(), - ) - trainer = Trainer(accelerator=accelerator) + accelerator = GPUAccelerator() + precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda") + training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin) + trainer = Trainer(strategy=training_type_plugin) We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 78c6503fea34d..f791df894d0c8 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -80,11 +80,10 @@ can then be passed into the Trainer directly or via a (custom) accelerator: trainer = Trainer(strategy=CustomDDPPlugin(), plugins=[CustomPrecisionPlugin()]) # fully custom accelerator and plugins - accelerator = MyAccelerator( - precision_plugin=CustomPrecisionPlugin(), - training_type_plugin=CustomDDPPlugin(), - ) - trainer = Trainer(accelerator=accelerator) + accelerator = MyAccelerator() + precision_plugin = MyPrecisionPlugin() + training_type_plugin = CustomDDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin) + trainer = Trainer(strategy=training_type_plugin) The full list of built-in plugins is listed below. diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 18fd855c94a60..093065394b337 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union import torch -from torch.nn import Module import pytorch_lightning as pl -from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type import TrainingTypePlugin class Accelerator: @@ -31,35 +28,14 @@ class Accelerator: - GPU - TPU - IPU - - Each Accelerator gets two plugins upon initialization: - One to handle differences from the training routine and one to handle different precisions. 
""" - def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None: - """ - Args: - precision_plugin: the plugin to handle precision-specific parts - - .. deprecated:: - The ``precision_plugin`` parameter has been deprecated and will be removed soon. - Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. - - training_type_plugin: the plugin to handle different training routines - """ - - self.training_type_plugin = training_type_plugin - - if precision_plugin is not None: - self.training_type_plugin._precision_plugin = precision_plugin - - def setup_environment(self) -> None: + def setup_environment(self, root_device: torch.device) -> None: """Setup any processes or distributed connections. This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. """ - self.training_type_plugin.setup_environment() def setup(self, trainer: "pl.Trainer") -> None: """Setup plugins for the trainer fit and creates optimizers. @@ -67,40 +43,6 @@ def setup(self, trainer: "pl.Trainer") -> None: Args: trainer: the trainer instance """ - self.training_type_plugin.setup(trainer) - - @property - def model(self) -> Module: - """Returns the model. - - This can also be a wrapped LightningModule. For retrieving the pure LightningModule use - :attr:`Accelerator.lightning_module` - """ - return self.training_type_plugin.model - - @model.setter - def model(self, new_model: Module) -> None: - self.training_type_plugin.model = new_model - - @property - def lightning_module(self) -> "pl.LightningModule": - """Returns the pure LightningModule. - - To get the potentially wrapped model use :attr:`Accelerator.model` - """ - return self.training_type_plugin.lightning_module - - @property - def root_device(self) -> torch.device: - """Returns the root device.""" - return self.training_type_plugin.root_device - - def teardown(self) -> None: - """This method is called to teardown the training process. - - It is the right place to release memory and free other resources. - """ - self.training_type_plugin.teardown() def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """Gets stats for a given device. diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 7d5786102d0b3..40c9a3c2b918c 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -15,7 +15,6 @@ import torch -import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -23,18 +22,14 @@ class CPUAccelerator(Accelerator): """Accelerator for CPU devices.""" - def setup(self, trainer: "pl.Trainer") -> None: + def setup_environment(self, root_device: torch.device) -> None: """ Raises: MisconfigurationException: If the selected device is not CPU. """ - if "cpu" not in str(self.training_type_plugin.root_device): - raise MisconfigurationException( - f"Device should be CPU, got {self.training_type_plugin.root_device} instead." 
-            )
-
-        return super().setup(trainer)
+        if "cpu" not in str(root_device):
+            raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")
 
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """CPU device stats aren't supported yet."""
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index c6c82d83c32f5..06ade654fca92 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -30,22 +30,19 @@ class GPUAccelerator(Accelerator):
     """Accelerator for GPU devices."""
 
-    def setup_environment(self) -> None:
+    def setup_environment(self, root_device: torch.device) -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
-        super().setup_environment()
-        if "cuda" not in str(self.training_type_plugin.root_device):
-            raise MisconfigurationException(
-                f"Device should be GPU, got {self.training_type_plugin.root_device} instead"
-            )
-        torch.cuda.set_device(self.training_type_plugin.root_device)
+        if "cuda" not in str(root_device):
+            raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
+        torch.cuda.set_device(root_device)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        # TODO refactor input from trainer to local_rank @four4fish
         self.set_nvidia_flags(trainer.local_rank)
-        super().setup(trainer)
         # clear cache before training
         torch.cuda.empty_cache()
 
@@ -74,10 +71,6 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
             return torch.cuda.memory_stats(device)
         return get_nvidia_gpu_stats(device)
 
-    def teardown(self) -> None:
-        super().teardown()
-        self.training_type_plugin._move_optimizer_state(torch.device("cpu"))
-
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index f116ed7f0f493..34c37dcd95e7f 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -15,11 +15,7 @@
 
 import torch
 
-import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
-from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
-from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -29,25 +25,6 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
-    def setup(self, trainer: "pl.Trainer") -> None:
-        """
-        Raises:
-            ValueError:
-                If the precision or training type plugin are unsupported.
-        """
-        if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
-            # this configuration should have been avoided in the accelerator connector
-            raise ValueError(
-                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
-                f" found: {self.training_type_plugin.precision_plugin}."
-            )
-        if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-            raise ValueError(
-                "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
-                f" found {self.training_type_plugin}."
-            )
-        return super().setup(trainer)
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for the given TPU device.
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 35fe3d053d0d4..a07ed1cc3dfab 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -99,8 +99,8 @@ def __init__( amp_level=None, plugins=plugins, ) - self._accelerator = self._accelerator_connector.accelerator - self._strategy = self._accelerator.training_type_plugin + self._strategy = self._accelerator_connector.training_type_plugin + self._accelerator = self._strategy.accelerator self._precision_plugin = self._strategy.precision_plugin self._models_setup: int = 0 @@ -398,7 +398,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) return seed_everything(seed=seed, workers=workers) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - self._accelerator.setup_environment() + self._strategy.setup_environment() # apply sharded context to prevent OOM run_method = partial(self._run_with_sharded_context, run_method) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 829735b0e0bed..46c6f4d0ac1bb 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -84,6 +84,7 @@ class DDPPlugin(ParallelPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, @@ -95,6 +96,7 @@ def __init__( **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, @@ -147,6 +149,7 @@ def setup_environment(self) -> None: self._call_children_scripts() self.setup_distributed() + super().setup_environment() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 975f4ba435b2d..76b0db3e1370d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -62,6 +62,7 @@ class DDPSpawnPlugin(ParallelPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, @@ -72,6 +73,7 @@ def __init__( **kwargs: Any, ): super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f30d15d495f9f..88fa34905ba1e 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -88,6 +88,7 @@ class DeepSpeedPlugin(DDPPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, zero_optimization: bool = True, stage: int = 2, remote_device: str = "cpu", @@ -273,6 +274,7 @@ def __init__( ) super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=cluster_environment, 
precision_plugin=precision_plugin, diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 69ba2fed867a7..decadb3f0ce5d 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -35,11 +35,13 @@ class DataParallelPlugin(ParallelPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ): super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io, diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index d1b1257622beb..475701e13f593 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -37,6 +37,7 @@ class DDPFullyShardedPlugin(DDPPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, cpu_offload: bool = False, flatten_parameters: bool = True, reshard_after_forward: bool = True, @@ -98,6 +99,7 @@ def __init__( """ super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 184183f5775e3..858d290b20d5b 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -41,11 +41,13 @@ class HorovodPlugin(ParallelPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ): super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io, diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 2763ad645facb..9a1ddaf9b38d1 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -62,6 +62,7 @@ class IPUPlugin(ParallelPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, device_iterations: int = 1, autoreport: bool = False, autoreport_dir: Optional[str] = None, @@ -86,6 +87,7 @@ def __init__( created options for validation/testing and predicting. 
""" super().__init__( + accelerator=accelerator, parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 293e52170d4b8..9e1967fc64409 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -34,12 +34,13 @@ class ParallelPlugin(TrainingTypePlugin, ABC): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) + super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 0159e86412cbf..ca95330281cb0 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -28,10 +28,11 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__( self, device: torch.device, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) + super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 011604468e1f5..34bb0b01f4ffd 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -34,6 +34,7 @@ class SingleTPUPlugin(SingleDevicePlugin): def __init__( self, device: int, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, @@ -41,7 +42,9 @@ def __init__( device = xm.xla_device(device) checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) + super().__init__( + accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + ) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 013b73459746f..f6c85b060e0a0 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -54,6 +54,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def __init__( self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[int]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, @@ -62,7 +63,10 @@ def __init__( ) -> None: checkpoint_io = checkpoint_io or 
XLACheckpointIO() super().__init__( - parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + accelerator=accelerator, + parallel_devices=parallel_devices, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 0c7e1f8410e57..bfdb01577aafe 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -42,12 +42,16 @@ class TrainingTypePlugin(ABC): loop.""" def __init__( - self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None + self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: + self._accelerator = accelerator self._model: Optional[Module] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io - self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() + self._precision_plugin = precision_plugin self.optimizers: List[Optimizer] = [] self.lr_schedulers: List[_LRScheduler] = [] self.optimizer_frequencies: List[int] = [] @@ -57,13 +61,25 @@ def __init__( f" Move your implementation to `{self.__class__.__name__}.teardown()` instead." ) + @property + def accelerator(self) -> "pl.accelerators.accelerator.Accelerator": + return self._accelerator + + @accelerator.setter + def accelerator(self, accelerator: "pl.accelerators.accelerator.Accelerator") -> None: + self._accelerator = accelerator + @property def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io @property def precision_plugin(self) -> PrecisionPlugin: - return self._precision_plugin + return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() + + @precision_plugin.setter + def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: + self._precision_plugin = precision_plugin @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: @@ -79,6 +95,7 @@ def setup_environment(self) -> None: This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. """ + self.accelerator.setup_environment(self.root_device) def setup_optimizers(self, trainer: "pl.Trainer") -> None: """Creates optimizers and schedulers. @@ -101,6 +118,7 @@ def setup(self, trainer: "pl.Trainer") -> None: Args: trainer: the trainer instance """ + self.accelerator.setup(trainer) self.setup_optimizers(trainer) self.setup_precision_plugin() @@ -425,6 +443,7 @@ def teardown(self) -> None: It is the right place to release memory and free other resources. 
""" + self._move_optimizer_state(torch.device("cpu")) @classmethod def register_plugins(cls, plugin_registry) -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 18a4da416946d..0154bc94acac1 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -174,8 +174,9 @@ def __init__( self._validate_accelerator_type() self._set_devices_if_none() - self._training_type_plugin_resolved = False - self.accelerator = self.select_accelerator() + self.training_type_plugin = self.final_training_type_plugin() + self.accelerator = self.training_type_plugin.accelerator + self._check_plugin_compatibility() # benchmarking # TODO: should this be moved to GPU accelerator? @@ -394,22 +395,23 @@ def precision_plugin(self) -> PrecisionPlugin: self._precision_plugin = self.select_precision_plugin() return self._precision_plugin - @property - def training_type_plugin(self) -> TrainingTypePlugin: - if self._training_type_plugin_resolved: - # avoid calling `resolve_training_type_plugin` multiple times - return self._training_type_plugin + def final_training_type_plugin(self) -> TrainingTypePlugin: if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) # attach checkpoint plugin to the training type plugin if self._checkpoint_io is not None: self._training_type_plugin.checkpoint_io = self._checkpoint_io - precision_plugin = self.precision_plugin - if precision_plugin is not None: - self._training_type_plugin._precision_plugin = precision_plugin - self._training_type_plugin_resolved = True - + if ( + isinstance(self.strategy, TrainingTypePlugin) and self.strategy._precision_plugin is None + ) or not isinstance(self.strategy, TrainingTypePlugin): + precision_plugin = self.precision_plugin + if precision_plugin is not None: + self._training_type_plugin.precision_plugin = precision_plugin + if (isinstance(self.strategy, TrainingTypePlugin) and self.strategy.accelerator is None) or not isinstance( + self.strategy, TrainingTypePlugin + ): + self._training_type_plugin.accelerator = self.select_accelerator() return self._training_type_plugin @property @@ -797,10 +799,7 @@ def select_accelerator(self) -> Accelerator: else: acc_cls = CPUAccelerator - accelerator = acc_cls(precision_plugin=None, training_type_plugin=self.training_type_plugin) - # transfer ownership of the plugins to the accelerator - self._training_type_plugin = proxy(self.training_type_plugin) - + accelerator = acc_cls() return accelerator def select_cluster_environment(self) -> ClusterEnvironment: @@ -1016,3 +1015,21 @@ def _is_slurm_managing_tasks(self) -> bool: total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0) return num_slurm_tasks == total_requested_devices + + def _check_plugin_compatibility(self) -> None: + """Checks that selected plugins are compatible with each other. + + Raises: + ValueError: If an invalid combination of Accelerator, TrainingTypePlugin, PrecisionPlugin is found. 
+ """ + if isinstance(self.accelerator, TPUAccelerator): + if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin): + raise ValueError( + f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," + f" found: {self.training_type_plugin.precision_plugin}." + ) + if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): + raise ValueError( + "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin`," + f" found {self.training_type_plugin}." + ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0455de7c278bd..b0ef15e33a46e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1106,7 +1106,7 @@ def _run( # SET UP TRAINING # ---------------------------- self._call_callback_hooks("on_before_accelerator_backend_setup") - self.accelerator.setup_environment() + self.training_type_plugin.setup_environment() self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later @@ -1114,7 +1114,7 @@ def _run( self._restore_modules_and_callbacks(ckpt_path) self._call_configure_sharded_model() # allow user to setup in model sharded environment - self.accelerator.setup(self) + self.training_type_plugin.setup(self) # ---------------------------- # INSPECT THE CORE LOOPS @@ -1124,7 +1124,7 @@ def _run( {Trainer.fit} or {Trainer.test} or {Trainer.predict} || | || spawn processes || - {self.accelerator.setup_environment} || + {self.training_type_plugin.setup_environment} || | || setup accelerator || and strategy || LIGHTNING @@ -1231,7 +1231,7 @@ def _teardown(self): """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback; those are handled by :meth:`_call_teardown_hook`.""" self.training_type_plugin.post_dispatch(self) - self.accelerator.teardown() + self.training_type_plugin.teardown() self._data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() @@ -1670,11 +1670,11 @@ def _on_exception(self) -> None: @property def accelerator(self) -> Accelerator: - return self._accelerator_connector.accelerator + return self.training_type_plugin.accelerator @property def training_type_plugin(self) -> TrainingTypePlugin: - return self.accelerator.training_type_plugin + return self._accelerator_connector.training_type_plugin @property def precision_plugin(self) -> PrecisionPlugin: @@ -1748,7 +1748,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: @property def lightning_module(self) -> "pl.LightningModule": - return self.accelerator.lightning_module + return self.training_type_plugin.lightning_module @property def optimizers(self) -> List[Optimizer]: @@ -1806,7 +1806,7 @@ def model(self) -> torch.nn.Module: To access the pure LightningModule, use :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. """ - return self.accelerator.model + return self.training_type_plugin.model @model.setter def model(self, model: torch.nn.Module) -> None: @@ -1817,7 +1817,7 @@ def model(self, model: torch.nn.Module) -> None: model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending on the backend. 
""" - self.accelerator.model = model + self.training_type_plugin.model = model """ General properties diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 51316c155368c..0ef10b4eb2a9f 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -397,7 +397,14 @@ def creates_processes_externally(self) -> bool: @mock.patch.dict( os.environ, - {"SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", "SLURM_LOCALID": "0"}, + { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + }, ) @mock.patch("torch.cuda.device_count", return_value=0) @mock.patch("pytorch_lightning.plugins.DDPPlugin.setup_distributed", autospec=True) @@ -411,9 +418,8 @@ class Prec(PrecisionPlugin): class TrainTypePlugin(SingleDevicePlugin): pass - ttp = TrainTypePlugin(device=torch.device("cpu")) - accelerator = Accel(training_type_plugin=ttp, precision_plugin=Prec()) - trainer = Trainer(accelerator=accelerator, fast_dev_run=True, num_processes=2) + ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec()) + trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2) assert isinstance(trainer.accelerator, Accel) assert isinstance(trainer.training_type_plugin, TrainTypePlugin) assert isinstance(trainer.precision_plugin, Prec) @@ -422,9 +428,8 @@ class TrainTypePlugin(SingleDevicePlugin): class DistributedPlugin(DDPPlugin): pass - ttp = DistributedPlugin() - accelerator = Accel(training_type_plugin=ttp, precision_plugin=Prec()) - trainer = Trainer(accelerator=accelerator, fast_dev_run=True, num_processes=2) + ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec()) + trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2) assert isinstance(trainer.accelerator, Accel) assert isinstance(trainer.training_type_plugin, DistributedPlugin) assert isinstance(trainer.precision_plugin, Prec) @@ -483,11 +488,11 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str): else: with pytest.deprecated_call(match=r"accelerator=.*\)` has been deprecated"): trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + assert isinstance(trainer.training_type_plugin, DDPShardedPlugin) with pytest.deprecated_call(match="Passing .* `strategy` to the `plugins`"): trainer = Trainer(plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + assert isinstance(trainer.training_type_plugin, DDPShardedPlugin) @pytest.mark.parametrize( @@ -1029,10 +1034,13 @@ def test_unsupported_tpu_choice(monkeypatch): with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64) - with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"): - Trainer(accelerator="tpu", precision=16) - with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"): - Trainer(accelerator="tpu", precision=16, amp_backend="apex") + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"): + with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is 
not supported"): + Trainer(accelerator="tpu", precision=16) + + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"): + with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"): + Trainer(accelerator="tpu", precision=16, amp_backend="apex") def test_unsupported_ipu_choice(monkeypatch): diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 553d842ed186e..2ef234b1ffde6 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -16,9 +16,9 @@ def test_restore_checkpoint_after_pre_dispatch_default(): """Assert default for restore_checkpoint_after_pre_dispatch is False.""" - plugin = SingleDevicePlugin(torch.device("cpu")) - accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch + plugin = SingleDevicePlugin( + accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=PrecisionPlugin() + ) assert not plugin.restore_checkpoint_after_pre_dispatch @@ -49,14 +49,16 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) - accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - - assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch + plugin = TestPlugin( + accelerator=CPUAccelerator(), + precision_plugin=PrecisionPlugin(), + device=torch.device("cpu"), + checkpoint_io=TorchCheckpointIO(), + ) assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch - trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True) + trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True) trainer.fit(model, ckpt_path=checkpoint_path) for func in (trainer.test, trainer.validate, trainer.predict): - accelerator.training_type_plugin.predispatched_called = False + plugin.predispatched_called = False func(model, ckpt_path=checkpoint_path) diff --git a/tests/accelerators/test_gpu.py b/tests/accelerators/test_gpu.py index 85ce0cd9f0f18..ece78da24972d 100644 --- a/tests/accelerators/test_gpu.py +++ b/tests/accelerators/test_gpu.py @@ -1,8 +1,6 @@ import torch from pytorch_lightning.accelerators import GPUAccelerator -from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin from tests.helpers.runif import RunIf @@ -11,10 +9,7 @@ def test_get_torch_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch >= 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") - GPUAccel = GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() - ) - gpu_stats = GPUAccel.get_device_stats(current_device) + gpu_stats = GPUAccelerator().get_device_stats(current_device) fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"] for f in fields: @@ -26,10 +21,7 @@ def test_get_torch_gpu_stats(tmpdir): def test_get_nvidia_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch < 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") - GPUAccel = 
GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() - ) - gpu_stats = GPUAccel.get_device_stats(current_device) + gpu_stats = GPUAccelerator().get_device_stats(current_device) fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"] for f in fields: diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 87154efbd478a..d404db964a87d 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -188,7 +188,7 @@ def test_optimization(tmpdir): def test_mixed_precision(tmpdir): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: - assert trainer.accelerator.model.precision == 16 + assert trainer.training_type_plugin.model.precision == 16 raise SystemExit model = IPUModel() @@ -203,8 +203,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st def test_pure_half_precision(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.accelerator.model.precision == 16 - for param in trainer.accelerator.model.parameters(): + assert trainer.training_type_plugin.model.precision == 16 + for param in trainer.training_type_plugin.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -212,7 +212,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: model = model.half() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) + assert isinstance(trainer.training_type_plugin, IPUPlugin) assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) assert trainer.training_type_plugin.precision_plugin.precision == 16 @@ -224,9 +224,9 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: def test_device_iterations_ipu_plugin(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.accelerator.training_type_plugin.device_iterations == 2 + assert trainer.training_type_plugin.device_iterations == 2 # assert device iterations has been set correctly within the poptorch options - poptorch_model = trainer.accelerator.training_type_plugin.poptorch_models[RunningStage.TRAINING] + poptorch_model = trainer.training_type_plugin.poptorch_models[RunningStage.TRAINING] assert poptorch_model._options.toDict()["device_iterations"] == 2 raise SystemExit @@ -238,7 +238,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: strategy=IPUPlugin(device_iterations=2), callbacks=TestCallback(), ) - assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) + assert isinstance(trainer.training_type_plugin, IPUPlugin) with pytest.raises(SystemExit): trainer.fit(model) @@ -251,7 +251,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # since ipu handle accumulation assert trainer.accumulation_scheduler.scheduling == {0: 1} # assert poptorch option have been set correctly - poptorch_model = trainer.accelerator.training_type_plugin.poptorch_models[RunningStage.TRAINING] + poptorch_model = trainer.training_type_plugin.poptorch_models[RunningStage.TRAINING] assert poptorch_model._options.Training.toDict()["gradient_accumulation"] == 2 
raise SystemExit @@ -356,9 +356,9 @@ def test_manual_poptorch_opts(tmpdir): ) trainer.fit(model) - assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - assert trainer.accelerator.training_type_plugin.training_opts == training_opts - assert trainer.accelerator.training_type_plugin.inference_opts == inference_opts + assert isinstance(trainer.training_type_plugin, IPUPlugin) + assert trainer.training_type_plugin.training_opts == training_opts + assert trainer.training_type_plugin.inference_opts == inference_opts @RunIf(ipu=True) @@ -380,7 +380,7 @@ def test_manual_poptorch_opts_custom(tmpdir): class TestCallback(Callback): def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # ensure dataloaders were correctly set up during training. - plugin = trainer.accelerator.training_type_plugin + plugin = trainer.training_type_plugin assert isinstance(plugin, IPUPlugin) assert plugin.training_opts.replication_factor == 2 assert plugin.inference_opts.replication_factor == 1 @@ -400,7 +400,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin, callbacks=TestCallback()) trainer.fit(model) - plugin = trainer.accelerator.training_type_plugin + plugin = trainer.training_type_plugin assert isinstance(plugin, IPUPlugin) training_opts = plugin.training_opts @@ -462,9 +462,9 @@ def test_default_opts(tmpdir): trainer = Trainer(default_root_dir=tmpdir, ipus=1, fast_dev_run=True) trainer.fit(model) - assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - inference_opts = trainer.accelerator.training_type_plugin.inference_opts - training_opts = trainer.accelerator.training_type_plugin.training_opts + assert isinstance(trainer.training_type_plugin, IPUPlugin) + inference_opts = trainer.training_type_plugin.inference_opts + training_opts = trainer.training_type_plugin.training_opts for opts in (inference_opts, training_opts): assert isinstance(opts, poptorch.Options) assert opts.Training.gradient_accumulation == 1 diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index fc1ce413cd494..65d607fc32ef3 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -13,7 +13,7 @@ # limitations under the License import collections from copy import deepcopy -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import torch @@ -288,25 +288,27 @@ def forward(self, x): def test_tpu_invalid_raises(): - accelerator = TPUAccelerator(object(), TPUSpawnPlugin()) + training_type_plugin = TPUSpawnPlugin(accelerator=TPUAccelerator(), precision_plugin=Mock()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): - accelerator.setup(object()) + Trainer(strategy=training_type_plugin) - accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin()) - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"): - accelerator.setup(object()) + training_type_plugin = DDPPlugin(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()) + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"): + Trainer(strategy=training_type_plugin) def test_tpu_invalid_raises_set_precision_with_strategy(): - accelerator = TPUAccelerator(object(), TPUSpawnPlugin(precision_plugin=object())) + accelerator = TPUAccelerator() + training_type_plugin = 
TPUSpawnPlugin(accelerator=accelerator, precision_plugin=object()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): - accelerator.setup(object()) + Trainer(strategy=training_type_plugin) - accelerator = TPUAccelerator(None, DDPPlugin(precision_plugin=TPUPrecisionPlugin())) + accelerator = TPUAccelerator() + training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( - ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" + ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" ): - accelerator.setup(object()) + Trainer(strategy=training_type_plugin) @RunIf(tpu=True) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 6967ea9a12bd7..c4a2eeaf74c0b 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -23,8 +23,8 @@ def test_invalid_on_cpu(tmpdir): MisconfigurationException, match="You selected accelerator to be `ddp_fully_sharded`, but GPU is not available." ): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp") - assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - trainer.accelerator.setup_environment() + assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin) + trainer.training_type_plugin.setup_environment() @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 1aaf89d052686..e99474efd8a7e 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -56,11 +56,11 @@ def test_ddp_with_2_gpus(): class BarrierModel(BoringModel): def setup(self, stage=None): - assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel) + assert not isinstance(self.trainer.training_type_plugin.model, DistributedDataParallel) self.trainer.training_type_plugin.barrier("barrier before model is wrapped") def on_train_start(self): - assert isinstance(self.trainer.accelerator.model, DistributedDataParallel) + assert isinstance(self.trainer.training_type_plugin.model, DistributedDataParallel) self.trainer.training_type_plugin.barrier("barrier after model is wrapped") @@ -110,8 +110,8 @@ def test_ddp_configure_ddp(): # test wrap the model if fitting trainer.state.fn = TrainerFn.FITTING trainer.training_type_plugin.connect(model) - trainer.accelerator.setup_environment() - trainer.accelerator.setup(trainer) + trainer.training_type_plugin.setup_environment() + trainer.training_type_plugin.setup(trainer) trainer.lightning_module.trainer = trainer assert isinstance(trainer.model, LightningModule) trainer._pre_dispatch() @@ -124,8 +124,8 @@ def test_ddp_configure_ddp(): ) # test do not wrap the model if trainerFN is not fitting trainer.training_type_plugin.connect(model) - trainer.accelerator.setup_environment() - trainer.accelerator.setup(trainer) + trainer.training_type_plugin.setup_environment() + trainer.training_type_plugin.setup(trainer) trainer.lightning_module.trainer = trainer trainer._pre_dispatch() # in DDPPlugin configure_ddp(), model are still LightningModule diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index b67988b3efecf..69d320b52d426 100644 --- 
a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = powerSGD.powerSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 4b56e9d389174..73ff8795a086d 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -133,8 +133,8 @@ def test_deepspeed_plugin_string(tmpdir, plugin): fast_dev_run=True, default_root_dir=tmpdir, strategy=plugin if isinstance(plugin, str) else plugin() ) - assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - assert trainer.accelerator.training_type_plugin.parallel_devices == [torch.device("cpu")] + assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) + assert trainer.training_type_plugin.parallel_devices == [torch.device("cpu")] @RunIf(deepspeed=True) @@ -147,7 +147,7 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed") - plugin = trainer.accelerator.training_type_plugin + plugin = trainer.training_type_plugin assert isinstance(plugin, DeepSpeedPlugin) assert plugin.parallel_devices == [torch.device("cpu")] assert plugin.config == deepspeed_config @@ -169,7 +169,7 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed", amp_backend=amp_backend, precision=precision ) - assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) assert 
isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin) assert trainer.training_type_plugin.precision_plugin.precision == precision @@ -235,8 +235,8 @@ def train_dataloader(self): class AssertCallback(Callback): def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None: - assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - config = trainer.accelerator.training_type_plugin.config + assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) + config = trainer.training_type_plugin.config # int value overrides auto mode expected_value = value if isinstance(value, int) else 1 @@ -688,8 +688,8 @@ class TestCallback(Callback): def on_train_batch_start( self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int ) -> None: - original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin - current_deepspeed_plugin = trainer.accelerator.training_type_plugin + original_deepspeed_plugin = initial_trainer.training_type_plugin + current_deepspeed_plugin = trainer.training_type_plugin assert isinstance(original_deepspeed_plugin, DeepSpeedPlugin) assert isinstance(current_deepspeed_plugin, DeepSpeedPlugin) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index e3b7e4986d9fb..9d7a72507b273 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -37,7 +37,7 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v def test_sharded_ddp_choice(tmpdir, strategy, expected): """Test to ensure that plugin is correctly chosen.""" trainer = Trainer(fast_dev_run=True, strategy=strategy) - assert isinstance(trainer.accelerator.training_type_plugin, expected) + assert isinstance(trainer.training_type_plugin, expected) @RunIf(min_gpus=1, fairscale=True) @@ -47,7 +47,7 @@ def test_sharded_ddp_choice(tmpdir, strategy, expected): def test_ddp_choice_sharded_amp(tmpdir, strategy, expected): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy) - assert isinstance(trainer.accelerator.training_type_plugin, expected) + assert isinstance(trainer.training_type_plugin, expected) @RunIf(skip_windows=True, fairscale=True)