diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index 0dd11baee769f..7568a82b3058e 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -15,7 +15,7 @@
 import os
 import platform
 import time
-from typing import Type, Union
+from typing import Type
 
 import pytest
 import torch
@@ -32,10 +32,8 @@
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_gpu():
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=1,
-        accelerator='ddp_spawn',
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
     )
 
@@ -45,11 +43,9 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=1,
         precision=16,
-        accelerator='ddp_spawn',
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
     )
 
@@ -59,10 +55,8 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu():
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=2,
-        accelerator='ddp_spawn',
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -73,11 +67,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=2,
         precision=16,
-        accelerator='ddp_spawn',
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -88,11 +80,9 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=2,
         precision=16,
-        accelerator='ddp_spawn',
-        plugin='ddp_sharded',
         model_cls=SeedTrainLoaderModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -105,11 +95,9 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
 )
 @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32")
 def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=args.gpus,
         precision=args.precision,
-        accelerator=args.accelerator,
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
     )
 
@@ -121,11 +109,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
 )
 @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
+    sharded_parity_test(
         gpus=args.gpus,
         precision=args.precision,
-        accelerator=args.accelerator,
-        plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
     )
 
@@ -138,10 +124,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
         Ensures same results using multiple optimizers across multiple GPUs
     """
-    plugin_parity_test(
-        plugin=DDPShardedPlugin(),
+    sharded_parity_test(
         gpus=2,
-        accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderMultipleOptimizersModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -155,10 +139,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
        Ensures using multiple optimizers across multiple GPUs with manual optimization
     """
-    plugin_parity_test(
-        plugin=DDPShardedPlugin(),
+    sharded_parity_test(
         gpus=2,
-        accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderManualModel,
         max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -273,9 +255,7 @@ def plugin_parity_test(
 
     Args:
         model_cls: Model class to use for test.
-        plugin: Plugin to parity test.
         seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process.
-        accelerator: Accelerator type for test.
         gpus: Number of GPUS to enable.
         precision: Whether to use AMP or normal FP32 training.
         max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
@@ -293,7 +273,7 @@ def plugin_parity_test(
         max_epochs=1,
         gpus=gpus,
         precision=precision,
-        accelerator=accelerator,
+        accelerator='ddp_spawn',
     )
 
     max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)
@@ -307,8 +287,7 @@ def plugin_parity_test(
         max_epochs=1,
         gpus=gpus,
         precision=precision,
-        accelerator=accelerator,
-        plugins=[plugin],
+        accelerator='ddp_sharded_spawn',
     )
 
     max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index a97edb21e504d..2ec118303d153 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -1,25 +1,4 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from pytorch_lightning.accelerators.legacy.accelerator import Accelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.cpu_accelerator import CPUAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp2_accelerator import DDP2Accelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp_accelerator import DDPAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp_cpu_hpc_accelerator import DDPCPUHPCAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp_cpu_spawn_accelerator import DDPCPUSpawnAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp_hpc_accelerator import DDPHPCAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.ddp_spawn_accelerator import DDPSpawnAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.dp_accelerator import DataParallelAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.horovod_accelerator import HorovodAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.legacy.tpu_accelerator import TPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.cpu import CPUAccelerator
+from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.accelerators.tpu import TPUAccelerator
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index b6c60bb1a7eee..1fa95ef4c13b5 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -34,7 +34,7 @@
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
-    TPUSpawnPlugin,
+    TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
 )
 from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -116,16 +116,12 @@ def __init__(
         # override dist backend when using tpus
         if self.on_tpu:
             self.distributed_backend = "tpu"
-            self.use_tpu = True
 
         # init flags for SLURM+DDP to work
         self.world_size = 1
         self.interactive_ddp_procs = []
         self.global_rank = 0
 
-        # NVIDIA setup
-        # self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)
-
         # benchmarking
         # TODO: should this be moved to GPU accelerator?
torch.backends.cudnn.benchmark = self.benchmark @@ -138,9 +134,6 @@ def __init__( # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) - # TODO: move this to TPU accelerator/plugin - self.on_colab_kaggle = os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE") - self.replace_sampler_ddp = replace_sampler_ddp @property @@ -256,23 +249,21 @@ def select_training_type_plugin(self): use_ddp_cpu_spawn = self.use_ddp and self.on_cpu use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks - # use_ddp_sharded = self.distributed_backend == "ddp_sharded" - # use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn" + use_ddp_sharded = self.distributed_backend == "ddp_sharded" + use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn" - if self.on_tpu: - ddp_plugin_cls = TPUSpawnPlugin - - # ddp script mode uses the same flags as TE # TODO: decouple from TE + # ddp script mode uses the same flags as TE if os.environ.get("PL_IN_DDP_SUBPROCESS", False): use_torchelastic_ddp = False - # fixme - # if use_ddp_sharded: - # ddp_plugin_cls = DDPShardedPlugin - # elif use_ddp_sharded_spawn: - # ddp_plugin_cls = DDPSpawnShardedPlugin - if use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: + if self.on_tpu: + ddp_plugin_cls = TPUSpawnPlugin + elif use_ddp_sharded: + ddp_plugin_cls = DDPShardedPlugin + elif use_ddp_sharded_spawn: + ddp_plugin_cls = DDPSpawnShardedPlugin + elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: ddp_plugin_cls = DDPPlugin elif use_ddp_spawn or use_ddp_cpu_spawn: ddp_plugin_cls = DDPSpawnPlugin @@ -328,6 +319,8 @@ def select_cluster_environment(self): return env def set_distributed_mode(self): + if isinstance(self.distributed_backend, Accelerator): + return if self.distributed_backend is None: if self.has_horovodrun(): @@ -355,27 +348,27 @@ def set_distributed_mode(self): # special case with TPUs elif self.distributed_backend == 'tpu': self._device_type = DeviceType.TPU - # set all other requested distrib. types adn if it was not set in the + # set all other requested distrib. types and if it was not set in the elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) # unless you request explicitly for CPU and some GPU are available use them _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend - if (self.num_gpus > 0 and not _on_cpu): + if self.num_gpus > 0 and not _on_cpu: self._device_type = DeviceType.GPU - _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + # _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) # DP and DDP2 cannot run without GPU - if (self.num_gpus == 0 and self._distrib_type in _distrib_types): - rank_zero_warn( - 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' 
- ) - # todo: in some cases it yield in comarison None and int - if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)): - self._distrib_type = DistributedType.DDP - else: - rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') - self._distrib_type = None + # if (self.num_gpus == 0 and self._distrib_type in _distrib_types): + # rank_zero_warn( + # 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' + # ) + # # todo: in some cases it yield in comarison None and int + # if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)): + # self._distrib_type = DistributedType.DDP + # else: + # rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') + # self._distrib_type = None # for DDP overwrite nb processes by requested GPUs if ( diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 833d5e1cb2a9a..f01cecac1615a 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,17 +1,22 @@ +import logging +import os + import torch from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException +log = logging.getLogger(__name__) + class GPUAccelerator(Accelerator): def setup(self, trainer, model): if "cuda" not in str(self.root_device): raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") + self.set_nvidia_flags() torch.cuda.set_device(self.root_device) model.to(self.root_device) - return super().setup(trainer, model) def on_train_start(self): @@ -25,3 +30,11 @@ def on_train_end(self): # clean up memory with torch.cuda.device(self.root_device): torch.cuda.empty_cache() + + @staticmethod + def set_nvidia_flags(): + # set the correct cuda visible devices (using pci order) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) + devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) + log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]") diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ec44a1eeb416b..d39e600820735 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -188,6 +188,7 @@ def _run_early_stopping_check(self, trainer, pl_module): return # short circuit if metric not present current = logs.get(self.monitor) + should_stop = False # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) @@ -204,5 +205,5 @@ def _run_early_stopping_check(self, trainer, pl_module): trainer.should_stop = True # stop every ddp process if any world process decides to stop - should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module) + should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop) trainer.should_stop = should_stop diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 25d6f39760b8a..100c84e1d9bdc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -458,7 +458,7 @@ def __resolve_ckpt_dir(self, trainer, pl_module): else f"version_{trainer.logger.version}" ) - 
version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name)) + version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name)) ckpt_path = os.path.join( save_dir, str(name), version, "checkpoints" diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 965dba8ad3a30..14ab52c3c6fba 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -276,6 +276,7 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx.") accelerator = self.trainer.accelerator_backend + training_type_plugin = self.trainer.training_type_plugin self._results.log( name, @@ -291,7 +292,7 @@ def log( sync_dist, sync_dist_op, sync_dist_group, - accelerator.sync_tensor, + training_type_plugin.reduce, self._current_dataloader_idx, self.device, ) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 395ae2f5ca168..e75a5568aae0f 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -137,8 +137,9 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) - elif trainer.amp_backend is not None: - trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) + # elif trainer.amp_backend is not None: + # # TODO: Adapt for new optimizer structure + # trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) else: with trainer.profiler.profile(profiler_name): diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 0990b547907e7..91cebaee2bd4c 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -4,6 +4,7 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 @@ -13,6 +14,8 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 __all__ = [ diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index 2139f5bac0020..41af4fe84c7f0 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -26,8 +26,11 @@ def master_address(self): def master_port(self): pass - def world_size(self): + 
def world_size(self) -> int: return self._world_size - def local_rank(self): + def local_rank(self) -> int: + pass + + def node_rank(self) -> int: pass diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 2d2c59d934d62..27115f116b862 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -32,7 +32,7 @@ def master_address(self): else: root_node = "127.0.0.1" - root_node = self._resolve_root_node_address(root_node) + root_node = self.resolve_root_node_address(root_node) os.environ["MASTER_ADDR"] = root_node log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") return root_node @@ -70,7 +70,10 @@ def world_size(self): def local_rank(self): return int(os.environ['SLURM_LOCALID']) - def _resolve_root_node_address(self, root_node): + def node_rank(self): + return int(os.environ['SLURM_NODEID']) + + def resolve_root_node_address(self, root_node): if '[' in root_node: name, numbers = root_node.split('[', maxsplit=1) number = numbers.split(',', maxsplit=1)[0] diff --git a/pytorch_lightning/plugins/legacy/plugin_connector.py b/pytorch_lightning/plugins/legacy/plugin_connector.py index 22f97bf8b77f3..95ec73f7dd80e 100644 --- a/pytorch_lightning/plugins/legacy/plugin_connector.py +++ b/pytorch_lightning/plugins/legacy/plugin_connector.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Sequence +from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.plugins.legacy.apex import ApexPlugin from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin @@ -26,22 +27,22 @@ class PluginConnector: - def __init__(self, trainer): + def __init__(self, trainer, plugins: Optional[Union[str, list]] = None): self.trainer = trainer - self.plugins = [] - self.ddp_plugin = DDPPlugin() + self.plugins = plugins or [] self.cloud_environment = None - - def on_trainer_init(self, plugins: Optional[Union[str, list]]): - self.plugins = plugins - if self.plugins is None: - self.plugins = [] + # self.ddp_plugin = DDPPlugin() self.plugins = self._convert_str_custom_plugins(self.plugins) - self.plugins = self._append_required_plugins(self.plugins) - self.__attach_ddp() + + # TODO: plugin dependencies + # self.plugins = self._append_required_plugins(self.plugins) + self.__attach_cluster() - self.__attach_amp() - self.__attach_apex() + + # TODO: attach custom training type and precision plugins + # self.__attach_ddp() + # self.__attach_amp() + # self.__attach_apex() def __attach_amp(self): amp_plugin = self.__attach_plugin(NativeAMPPlugin) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 21dec5bc5ccda..1d1f203afa38a 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -10,3 +10,5 @@ from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin +from 
pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index f45c3dcb93bb6..335f65b3e3fbb 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -45,6 +45,7 @@ def setup(self, model): self.global_rank = hvd.rank() self.local_rank = hvd.local_rank() + self.world_size = hvd.size() rank_zero_only.rank = self.global_rank self.model_to_device() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 89f2329512e5e..bda5d161da33b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -64,18 +64,6 @@ def barrier(self, name: Optional[str] = None) -> None: def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes""" - # TODO method this is currently unused. Check after complete refactors are pushed - def set_nvidia_flags(self, is_slurm_managing_tasks: bool, device_ids: Optional[Sequence]) -> None: - if device_ids is None: - return - - # set the correct cuda visible devices (using pci order) - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) - devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids) - if self.lightning_module is not None: - log.info(f"LOCAL_RANK: {self.lightning_module.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: """Reduce the early stopping decision across all possibly spawned processes""" return should_stop diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 001b0b9ed3e0d..8d1a482deff15 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -73,7 +73,7 @@ def restore_weights(self, model: LightningModule) -> None: self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU) # wait for all to catch up - self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') + self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights') # clear cache after restore if self.trainer._device_type == DeviceType.GPU: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index f6700187c3912..812584b3ab5e4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -31,8 +31,10 @@ class LoggerConnector: - def __init__(self, trainer): + + def __init__(self, trainer, log_gpu_memory): self.trainer = trainer + self.log_gpu_memory = log_gpu_memory self._callback_metrics = MetricsHolder() self._evaluation_callback_metrics = MetricsHolder(to_float=True) self._logged_metrics = MetricsHolder() @@ -217,8 +219,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): and global_step for the rest. 
""" # add gpu memory - if self.trainer._device_type == DeviceType.GPU and self.trainer.log_gpu_memory: - mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory) + if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + mem_map = memory.get_memory_profile(self.log_gpu_memory) metrics.update(mem_map) # add norms diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 673e8765ed51f..68fa7354bdcde 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -36,8 +36,6 @@ def copy_trainer_model_properties(self, model): m._distrib_type = str(self.trainer._distrib_type) m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing - m.tpu_local_core_rank = self.trainer.tpu_local_core_rank - m.tpu_global_core_rank = self.trainer.tpu_global_core_rank m.precision = self.trainer.precision def get_model(self): @@ -45,5 +43,5 @@ def get_model(self): def _get_reference_model(self, model): if self.trainer.accelerator_backend: - return self.trainer.accelerator_backend.get_reference_model(model) + return self.trainer.accelerator_backend.lightning_module return model diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index ad860c0b154b2..02552dd67de26 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -1,14 +1,8 @@ import os -import re import signal from subprocess import call -import torch -import torch.distributed as torch_distrib - from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import DeviceType, DistributedType -from pytorch_lightning.utilities.distributed import rank_zero_info class SLURMConnector: @@ -16,57 +10,6 @@ class SLURMConnector: def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, num_gpu_nodes): - self.configure_slurm_ddp(num_gpu_nodes) - - def configure_slurm_ddp(self, num_gpu_nodes): - self.trainer.is_slurm_managing_tasks = False - - # extract SLURM flag vars - # whenever we have the correct number of tasks, we let slurm manage processes - # otherwise we launch the required number of processes - if self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): - self.trainer.num_requested_gpus = self.trainer.num_gpus * num_gpu_nodes - self.trainer.num_slurm_tasks = 0 - try: - self.trainer.num_slurm_tasks = int(os.environ['SLURM_NTASKS']) - self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_requested_gpus - - # enable slurm cpu - if self.trainer.num_requested_gpus == 0: - self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_processes - - # in interactive mode we don't manage tasks - job_name = os.environ['SLURM_JOB_NAME'] - if job_name == 'bash': - self.trainer.is_slurm_managing_tasks = False - # todo: specify the possible exception - except Exception: - # likely not on slurm, so set the slurm managed flag to false - self.trainer.is_slurm_managing_tasks = False - - # used for tests only, set this flag to simulate slurm managing a task - should_fake = os.environ.get('FAKE_SLURM_MANAGING_TASKS') - if should_fake and int(should_fake): - self.trainer.is_slurm_managing_tasks = True - - # notify user the that slurm is managing tasks - if self.trainer.is_slurm_managing_tasks: - 
rank_zero_info('Multi-processing is handled by Slurm.') - - # todo: the same function as slurm_environment.py `_resolve_root_node_address` - def resolve_root_node_address(self, root_node): - if '[' in root_node: - name, numbers = root_node.split('[', maxsplit=1) - number = numbers.split(',', maxsplit=1)[0] - if '-' in number: - number = number.split('-')[0] - - number = re.sub('[^0-9]', '', number) - root_node = name + number - - return root_node - def register_slurm_signal_handlers(self): # see if we're using slurm (not interactive) on_slurm = False @@ -112,48 +55,3 @@ def term_handler(self, signum, frame): # Todo: required argument `signum` is not used # Todo: required argument `frame` is not used log.info("bypassing sigterm") - - # todo: this is the same func as slurm_environment.py `master_port` - def connect_ddp(self, global_rank: int, world_size: int) -> None: - """ - Sets up environment variables necessary for pytorch distributed communications - based on slurm environment. - """ - # use slurm job id for the port number - # guarantees unique ports across jobs from same grid search - default_port = os.environ.get("SLURM_JOB_ID") - if default_port: - # use the last 4 numbers in the job id as the id - default_port = default_port[-4:] - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - else: - default_port = 12910 - - # if user gave a port number, use that one instead - if "MASTER_PORT" in os.environ: - default_port = os.environ["MASTER_PORT"] - else: - os.environ["MASTER_PORT"] = str(default_port) - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - # figure out the root node addr - root_node = os.environ.get("SLURM_NODELIST") - if root_node: - root_node = root_node.split(" ")[0] - else: - root_node = "127.0.0.1" - - root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node) - os.environ["MASTER_ADDR"] = root_node - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - - torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo" - - if not torch.distributed.is_initialized(): - log.info( - f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" - ) - torch_distrib.init_process_group( - torch_backend, rank=global_rank, world_size=world_size - ) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5031357b41615..eeda8ab81bdf3 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -62,7 +62,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: # ddp_spawn + num_workers > 0 don't mix! tell the user is_dataloader = isinstance(dataloader, DataLoader) - using_spawn = self.distributed_backend == "ddp_spawn" + using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn" if is_dataloader and not on_windows: if dataloader.num_workers > 0 and using_spawn: rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!' 
@@ -92,8 +92,9 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler) - if self.replace_sampler_ddp and need_dist_sampler: + is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) + if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( 'You seem to have configured a sampler in your DataLoader. This will be replaced ' @@ -314,7 +315,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: dataloader = self._flatten_dl_only(dataloader) if self.accelerator_backend is not None: - self.accelerator_backend.barrier('get_dataloaders') + self.training_type_plugin.barrier('get_dataloaders') return dataloader def _flatten_dl_only(self, dataloaders): diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 2aaed17e9818c..20438f427d315 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -141,27 +141,6 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None): raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') return lr_schedulers - def reinit_scheduler_properties(self, optimizers: list, schedulers: list): - # Reinitialize optimizer.step properties added by schedulers - for scheduler in schedulers: - scheduler = scheduler['scheduler'] - - for optimizer in optimizers: - # check that we dont mix users optimizers and schedulers - if scheduler.optimizer == optimizer: - # Find the mro belonging to the base lr scheduler class - for i, mro in enumerate(scheduler.__class__.__mro__): - if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau): - idx = i - state = scheduler.state_dict() - else: - state = None - - scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) - if state is not None: - scheduler.load_state_dict(state) - - class _MockOptimizer(Optimizer): """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from `configure_optimizers`. 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 760a621db6914..39dcbc6c7c3e0 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -15,10 +15,11 @@ import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import cast, List, Optional, Type, TypeVar, Union +from typing import cast, List, Optional, Type, TypeVar, Union, Any -from pytorch_lightning.accelerators.legacy.accelerator import Accelerator -from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint, EarlyStopping from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger @@ -41,6 +42,9 @@ if _HOROVOD_AVAILABLE: import horovod.torch as hvd +from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.loggers.base import LightningLoggerBase +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger class TrainerProperties(ABC): @@ -59,14 +63,104 @@ class TrainerProperties(ABC): _default_root_dir: str _weights_save_path: str accelerator_backend: Accelerator - logger: LightningLoggerBase - model_connector: ModelConnector - checkpoint_connector: CheckpointConnector - callbacks: List[Callback] num_nodes: int num_processes: int + accelerator_connector: BackendConnector _lightning_optimizers = None + @property + def accelerator(self): + return self.accelerator_connector.accelerator + + @property + def accelerator_backend(self): + # for backward compatibility + return self.accelerator + + @property + def distributed_backend(self): + # for backward compatibility + return self.accelerator_connector.distributed_backend + + @property + def training_type_plugin(self): + return self.accelerator.training_type_plugin + + @property + def precision_plugin(self): + return self.accelerator.precision_plugin + + @property + def global_rank(self): + return self.accelerator.training_type_plugin.global_rank + + @property + def local_rank(self): + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + + @property + def node_rank(self): + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + + @property + def world_size(self): + # some training types define a world size + return getattr(self.accelerator.training_type_plugin, "world_size", 1) + + @property + def on_gpu(self): + return self.accelerator_connector.on_gpu + + @property + def on_tpu(self): + return self.accelerator_connector.on_tpu + + @property + def use_dp(self): + return self.accelerator_connector.use_dp + + @property + def use_ddp(self): + return self.accelerator_connector.use_ddp + + @property + def use_ddp2(self): + return self.accelerator_connector.use_ddp2 + + @property + def use_horovod(self): + return self.accelerator_connector.use_horovod + + @property + def use_tpu(self): + return self.accelerator_connector.on_tpu + + @property + def _distrib_type(self): + return self.accelerator_connector._distrib_type + + @property + def _device_type(self): + return self.accelerator_connector._device_type + + 
@property + def num_nodes(self): + return self.accelerator_connector.num_nodes + + @property + def num_processes(self): + return self.accelerator_connector.num_processes + + @property + def root_gpu(self): + return self.accelerator_connector.root_gpu + + @property + def data_parallel_device_ids(self): + return self.accelerator_connector.parallel_device_ids + @property def log_dir(self): if self.checkpoint_callback is not None: @@ -171,12 +265,13 @@ def match_env_arguments(cls) -> Namespace: def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: return add_argparse_args(cls, parent_parser) + @property + def gpus(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.gpus + @property def num_gpus(self) -> int: - gpus = self.data_parallel_device_ids - if gpus is None: - return 0 - return len(gpus) + return self.accelerator_connector.num_gpus @property def data_parallel(self) -> bool: @@ -203,7 +298,7 @@ def disable_validation(self) -> bool: @property def enable_validation(self) -> bool: """ Check if we should run validation during training. """ - model_ref = self.model_connector.get_model() + model_ref = self.get_model() val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 return val_loop_enabled @@ -264,8 +359,31 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: def save_checkpoint(self, filepath, weights_only: bool = False): self.checkpoint_connector.save_checkpoint(filepath, weights_only) + @property + def model(self) -> Any: + """ + The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + To access the pure LightningModule, use + :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. + """ + return self.accelerator.model + + @model.setter + def model(self, model: Any): + """ + Setter for the model, pass-through to accelerator and plugin where the model reference is stored. + Used by the Tuner to reset the state of Trainer and Accelerator. + + Args: + model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending + on the backend. 
+ """ + self.accelerator.model = model + def get_model(self): - return self.model_connector.get_model() + # TODO: rename this to lightning_module (see training type plugin) + # backward compatible + return self.lightning_module @property def lightning_optimizers(self): @@ -273,6 +391,47 @@ def lightning_optimizers(self): self.convert_to_lightning_optimizers() return self._lightning_optimizers + @property + def lightning_module(self): + return self.training_type_plugin.lightning_module + + @property + def optimizers(self): + return self.accelerator.optimizers + + @optimizers.setter + def optimizers(self, new_optims): + self.accelerator.optimizers = new_optims + + @property + def lr_schedulers(self): + return self.accelerator.lr_schedulers + + @lr_schedulers.setter + def lr_schedulers(self, new_schedulers): + self.accelerator.lr_schedulers = new_schedulers + + @property + def optimizer_frequencies(self): + return self.accelerator.optimizer_frequencies + + @optimizer_frequencies.setter + def optimizer_frequencies(self, new_freqs): + self.accelerator.optimizer_frequencies = new_freqs + + @property + def amp_backend(self): + return self.accelerator.amp_backend + + @property + def precision(self): + return self.accelerator.precision + + @property + def scaler(self): + return self.accelerator.scaler + + # TODO: refactor this so that it can be done in LightningOptimizer def __getstate__(self): # remove lightning_optimizers self._lightning_optimizers = None @@ -289,8 +448,9 @@ def require_distributed_sampler(self): @property def distributed_sampler_kwargs(self): if self.accelerator_backend is not None: - return self.accelerator_backend.distributed_sampler_kwargs + return self.training_type_plugin.distributed_sampler_kwargs + # TODO: make sure the cases below are handled by the training_type_plugin if self._device_type == DeviceType.TPU: kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ba34c49581038..5cdfa5021acb8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -23,11 +23,12 @@ from torch.utils.data import DataLoader from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.legacy.accelerator import Accelerator -from pytorch_lightning.accelerators.legacy.accelerator_connector import AcceleratorConnector +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.accelerators.accelerator_connector import BackendConnector from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins.legacy.plugin_connector import PluginConnector @@ -42,7 +43,6 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector -from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from 
pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector @@ -54,6 +54,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -62,12 +63,13 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.model_utils import is_overridden # warnings to ignore in trainer warnings.filterwarnings( - 'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead' + "ignore", message="torch.distributed.reduce_op is deprecated, " "please use torch.distributed.ReduceOp instead" ) +os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" class Trainer( @@ -114,7 +116,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, sync_batchnorm: bool = False, precision: int = 32, - weights_summary: Optional[str] = 'top', + weights_summary: Optional[str] = "top", weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, @@ -294,20 +296,36 @@ def __init__( reload when reaching the minimum length of datasets. """ super().__init__() - self._device_type = DeviceType.CPU - self._distrib_type = None self._running_stage = None self._predicting = False + distributed_backend = distributed_backend or accelerator + # init connectors self.dev_debugger = InternalDebugger(self) self.config_validator = ConfigValidator(self) self.data_connector = DataConnector(self) self.optimizer_connector = OptimizerConnector(self) - self.accelerator_connector = AcceleratorConnector(self) - self.logger_connector = LoggerConnector(self) + self.plugin_connector = PluginConnector(self, plugins) + self.accelerator_connector = BackendConnector( + num_processes, + tpu_cores, + distributed_backend, + auto_select_gpus, + gpus, + num_nodes, + sync_batchnorm, + benchmark, + replace_sampler_ddp, + deterministic, + precision, + amp_backend, + amp_level, + self.plugin_connector.cloud_environment + ) + self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) - self.precision_connector = PrecisionConnector(self) + # self.precision_connector = PrecisionConnector(self) self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) self.training_tricks_connector = TrainingTricksConnector(self) @@ -315,13 +333,11 @@ def __init__( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.accelerator_backend = None self.evaluation_loop = EvaluationLoop(self) self.train_loop = TrainLoop(self, multiple_trainloader_mode) - self.plugin_connector = PluginConnector(self) # training state - self.model = None + self.weights_summary = weights_summary self.shown_warnings = set() # init callbacks @@ -352,22 +368,6 @@ def __init__( gradient_clip_val, track_grad_norm, 
accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan ) - # init accelerator related flags - self.accelerator_connector.on_trainer_init( - num_processes, - tpu_cores, - accelerator, - distributed_backend, - auto_select_gpus, - gpus, - num_nodes, - log_gpu_memory, - sync_batchnorm, - benchmark, - replace_sampler_ddp, - deterministic, - ) - # init train loop related flags # TODO: remove in 1.3.0 if automatic_optimization is None: @@ -413,10 +413,11 @@ def __init__( ) # set precision - self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) + # self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) # last thing are the plugins which override whatever the trainer used by default - self.plugin_connector.on_trainer_init(plugins) + # TODO: probably not needed anymore after refactor + # self.plugin_connector.on_trainer_init(plugins) # Callback system self.on_init_end() @@ -458,45 +459,53 @@ def fit( # bookkeeping # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) + self.testing = os.environ.get("PL_TESTING_MODE", self.testing) # ---------------------------- # SET UP TRAINING # ---------------------------- - self.accelerator_backend = self.accelerator_connector.select_accelerator() - self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator_backend.setup(model) - - # ---------------------------- - # INSPECT THESE FOR MAIN LOOPS - # ---------------------------- - # assign training and eval functions... inspect these to see the train and eval loops :) - self.accelerator_backend.train_loop = self.train - self.accelerator_backend.validation_loop = self.run_evaluation - self.accelerator_backend.test_loop = self.run_evaluation + # self.accelerator_backend = self.accelerator_connector.select_accelerator() + self.accelerator_backend.setup(self, model) + self.train_loop.setup_training(model) # ---------------------------- # TRAIN # ---------------------------- # hook - self.call_hook('on_fit_start') - results = self.accelerator_backend.train() + self.call_hook("on_fit_start") + + # plugin will setup training (e.g. ddp will launch child processes) + # TODO: the old setup is now called "pre_training", where should this hook be called now? + self.call_hook("on_before_accelerator_backend_setup", model) + self.training_type_plugin.pre_training() + self.precision_plugin.pre_training() + + self.call_setup_hook(self.lightning_module) + + # double dispatch: let the plugin initiate the training/test loop. 
+ if self.testing: + self.training_type_plugin.start_testing(self) + else: + self.training_type_plugin.start_training(self) + + self.precision_plugin.post_training() + self.training_type_plugin.post_training() self.accelerator_backend.teardown() + results = self.training_type_plugin.results # ---------------------------- # POST-Training CLEAN UP # ---------------------------- # hook - self.call_hook('on_fit_end') + self.call_hook("on_fit_end") # hook - self.teardown('fit') - if self.is_function_implemented('teardown'): - model.teardown('fit') + self.teardown("fit") + if self.is_function_implemented("teardown"): + model.teardown("fit") # return 1 when finished # used for testing or when we need to know that training succeeded - if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED @@ -527,7 +536,44 @@ def _set_wide_running_stage(self, stage): self._running_stage = stage + def pre_training_routine(self): + # wait for all to join if on distributed + self.accelerator.training_type_plugin.barrier("setup_training") + + # register auto-resubmit when on SLURM + self.slurm_connector.register_slurm_signal_handlers() + + # -------------------------- + # Pre-train + # -------------------------- + # on pretrain routine start + ref_model = self.get_model() + + self.on_pretrain_routine_start(ref_model) + if self.is_function_implemented("on_pretrain_routine_start"): + ref_model.on_pretrain_routine_start() + + # print model summary + if self.is_global_zero and self.weights_summary is not None and not self.testing: + if self.weights_summary in ModelSummary.MODES: + ref_model.summarize(mode=self.weights_summary) + else: + raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) + + # restore training and model before hpc is called + self.checkpoint_connector.restore_weights(ref_model) + + # on pretrain routine end + self.on_pretrain_routine_end(ref_model) + if self.is_function_implemented("on_pretrain_routine_end"): + ref_model.on_pretrain_routine_end() + def train(self): + self.pre_training_routine() + + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + self.run_sanity_check(self.get_model()) # set stage for logging @@ -563,7 +609,7 @@ def train(self): return # update LR schedulers - self.optimizer_connector.update_learning_rates(interval='epoch') + self.optimizer_connector.update_learning_rates(interval="epoch") # early stopping met_min_epochs = epoch >= self.min_epochs - 1 @@ -572,14 +618,18 @@ def train(self): if self.should_stop: if met_min_epochs and met_min_steps: return - log.info( - 'Trainer was signaled to stop but required minimum epochs' - f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' - ' not been met. Training will continue...' - ) + else: + log.info( + "Trainer was signaled to stop but required minimum epochs" + f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" + " not been met. Training will continue..." + ) + + # hook + self.train_loop.on_train_end() except KeyboardInterrupt: - rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') + rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press ctrl+c many times... 
only shutdown once if not self.interrupted: @@ -698,6 +748,9 @@ def track_output_for_epoch_end(self, outputs, output): return outputs def run_test(self): + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + # only load test dataloader for testing # self.reset_test_dataloader(ref_model) with self.profiler.profile("run_test_evaluation"): @@ -716,7 +769,7 @@ def run_test(self): return eval_loop_results def run_sanity_check(self, ref_model): - using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) + using_val_step = ref_model.val_dataloader is not None and is_overridden("validation_step", ref_model) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 # run tiny validation (if validation defined) @@ -750,7 +803,7 @@ def test( self, model: Optional[LightningModule] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = "best", verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ): @@ -784,18 +837,18 @@ def test( # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + "You cannot pass test_dataloaders to trainer.test if you supply a datamodule" ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, "test") if model is not None: results = self.__test_given_model(model, test_dataloaders) else: results = self.__test_using_best_weights(ckpt_path, test_dataloaders) - self.teardown('test') + self.teardown("test") self._set_wide_running_stage(None) @@ -805,7 +858,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model = self.get_model() # if user requests the best checkpoint but we don't have it, error - if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: + if ckpt_path == "best" and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' ) @@ -813,20 +866,20 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # load best weights if ckpt_path is not None: # ckpt_path is 'best' so load the best model - if ckpt_path == 'best': + if ckpt_path == "best": ckpt_path = self.checkpoint_callback.best_model_path if len(ckpt_path) == 0: rank_zero_warn( - f'.test() found no path for the best weights, {ckpt_path}. Please ' - f'specify a path for a checkpoint .test(ckpt_path=PATH)' + f".test() found no path for the best weights, {ckpt_path}. 
Please " + f"specify a path for a checkpoint .test(ckpt_path=PATH)" ) return {} - if self.accelerator_backend is not None and not self._device_type == DeviceType.TPU: - self.accelerator_backend.barrier() + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt['state_dict']) + model.load_state_dict(ckpt["state_dict"]) # attach dataloaders if test_dataloaders is not None: @@ -835,16 +888,15 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path self.testing = True - os.environ['PL_TESTING_MODE'] = '1' - self.model = model + os.environ["PL_TESTING_MODE"] = "1" results = self.fit(model) self.testing = False - del os.environ['PL_TESTING_MODE'] + del os.environ["PL_TESTING_MODE"] # teardown - if self.is_function_implemented('teardown'): + if self.is_function_implemented("teardown"): model_ref = self.get_model() - model_ref.teardown('test') + model_ref.teardown("test") return results @@ -857,13 +909,12 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval self.testing = True - self.model = model results = self.fit(model) self.testing = False # teardown - if self.is_function_implemented('teardown'): - model.teardown('test') + if self.is_function_implemented("teardown"): + model.teardown("test") return results @@ -952,7 +1003,7 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = "test" if self.testing else "fit" if self.datamodule is not None: called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit if not called: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 85c5758ec27be..695741ed3cd22 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -18,6 +18,7 @@ import numpy as np import torch +from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary @@ -102,13 +103,6 @@ def should_skip_training(self): return False def on_train_start(self): - # clear cache before training - if self.trainer._device_type == DeviceType.GPU and self.trainer.root_gpu is not None: - # use context because of: - # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 - with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"): - torch.cuda.empty_cache() - # hook self.trainer.call_hook("on_train_start") @@ -116,8 +110,8 @@ def on_train_start(self): self.trainer.profile_connector.on_train_start(self.trainer) def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): - # bind logger and other properties - self.trainer.model_connector.copy_trainer_model_properties(model) + # # bind logger and other properties + # self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, "hparams"): @@ -141,11 +135,7 @@ def setup_training(self, model: LightningModule): # -------------------------- # Setup?? 
# -------------------------- - ref_model = self.trainer.get_model() - - # set the ranks and devices - self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank - self.trainer.accelerator_backend.dist.device = ref_model.device + ref_model = model # give model convenience properties ref_model.trainer = self.trainer @@ -166,36 +156,6 @@ def setup_training(self, model: LightningModule): self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() - # wait for all to join if on distributed - self.trainer.accelerator_backend.barrier("setup_training") - - # register auto-resubmit when on SLURM - self.trainer.slurm_connector.register_slurm_signal_handlers() - - # -------------------------- - # Pre-train - # -------------------------- - # on pretrain routine start - self.trainer.on_pretrain_routine_start(ref_model) - if self.trainer.is_function_implemented("on_pretrain_routine_start"): - ref_model.on_pretrain_routine_start() - - # print model summary - if self.trainer.is_global_zero and not self.trainer.testing: - ref_model.summarize(mode=self.trainer.weights_summary) - - # track model now. - # if cluster resets state, the model will update with the saved weights - self.trainer.model = model - - # restore training state and model weights before hpc is called - self.trainer.checkpoint_connector.restore_weights(model) - - # on pretrain routine end - self.trainer.on_pretrain_routine_end(ref_model) - if self.trainer.is_function_implemented("on_pretrain_routine_end"): - ref_model.on_pretrain_routine_end() - def on_train_end(self): if self._teardown_already_run: return @@ -518,12 +478,15 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) + def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + self.trainer.accelerator_backend.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + def track_and_norm_grad(self, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() # clip gradients - self.trainer.accelerator_backend.clip_gradients(optimizer) + self.trainer.accelerator_backend.clip_gradients(optimizer, self.trainer.gradient_clip_val) self._cur_grad_norm_dict = grad_norm_dic def _track_gradient_norm(self): @@ -778,8 +741,8 @@ def block_ddp_sync_behaviour(self): context manager with sync behaviour off """ - if self.trainer.accelerator_backend is not None and self.automatic_optimization: - yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour() + if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and self.automatic_optimization: + yield self.trainer.training_type_plugin.block_backward_sync() else: yield None @@ -844,12 +807,14 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") + should_accumulate = self.should_accumulate() + # backward can be called manually in the training loop if isinstance(result, torch.Tensor): - self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs) + self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator_backend.backward( - result.closure_loss, optimizer, opt_idx, *args, **kwargs + result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs ) if not 
self.should_accumulate():
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index fbed98ae2baa7..18557ea366f74 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, MutableSequence, Optional, Union
+from typing import Any, List, MutableSequence, Optional, Tuple, Union

 import torch

 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -145,9 +145,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
     return gpus


-def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]:
+def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]:
     assert gpus is not None
-    if isinstance(gpus, MutableSequence):
+    if isinstance(gpus, (MutableSequence, tuple)):
         return list(gpus)

     # must be an int
@@ -176,7 +176,7 @@ def _check_data_type(device_ids: Any) -> None:
         device_ids: gpus/tpu_cores parameter as passed to the Trainer
     """
     if device_ids is not None and \
-            (not isinstance(device_ids, (int, str, MutableSequence)) or isinstance(device_ids, bool)):
+            (not isinstance(device_ids, (int, str, MutableSequence, tuple)) or isinstance(device_ids, bool)):
         raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.")


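To make the intent of the relaxed type check concrete: a tuple of device indices is now treated exactly like the equivalent list. A minimal sketch of the expected behaviour, mirroring the parametrized cases added to tests/models/test_gpu.py (illustrative only, not part of the patch):

```python
from unittest import mock

from pytorch_lightning.utilities import device_parser

# Pretend four GPUs are visible so the ids below survive sanitization.
with mock.patch('torch.cuda.device_count', return_value=4):
    assert device_parser.parse_gpu_ids([1, 3]) == [1, 3]  # lists were already accepted
    assert device_parser.parse_gpu_ids((1, 3)) == [1, 3]  # tuples now behave the same way
```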
diff --git a/tests/accelerators/legacy/test_accelerator_connector.py b/tests/accelerators/legacy/test_accelerator_connector.py
index be2cd00ae4a62..a1f9395af6771 100644
--- a/tests/accelerators/legacy/test_accelerator_connector.py
+++ b/tests/accelerators/legacy/test_accelerator_connector.py
@@ -16,9 +16,14 @@
 from unittest import mock

 import pytest
-
-from pytorch_lightning import accelerators, Trainer
-from pytorch_lightning.accelerators import Accelerator
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.cpu import CPUAccelerator
+from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.plugins import SingleDevicePlugin, DDPPlugin, DDPSpawnPlugin, DDP2Plugin
+from pytorch_lightning.plugins import PrecisionPlugin
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.utilities import DistributedType
@@ -26,81 +31,47 @@


 def test_accelerator_choice_cpu(tmpdir):
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.accelerator_backend, accelerators.CPUAccelerator)
-            assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
-        callbacks=[CB()]
     )
-    trainer.fit(model)
+    assert isinstance(trainer.accelerator_backend, CPUAccelerator)
+    assert isinstance(trainer.training_type_plugin, SingleDevicePlugin)


 def test_accelerator_choice_ddp_cpu(tmpdir):
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSpawnAccelerator)
-            assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
         accelerator='ddp_cpu',
-        num_processes=2,
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.accelerator_backend, CPUAccelerator)
+    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)


 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp(tmpdir):
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPAccelerator)
-            assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
         accelerator='ddp',
         gpus=1,
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.accelerator_backend, GPUAccelerator)
+    assert isinstance(trainer.training_type_plugin, DDPPlugin)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)


 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp_spawn(tmpdir):
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPSpawnAccelerator)
-            assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
         accelerator='ddp_spawn',
         gpus=1,
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.accelerator_backend, GPUAccelerator)
+    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)


 @mock.patch.dict(os.environ, {
@@ -114,11 +85,13 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_slurm(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
-            assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
-            assert trainer.accelerator_backend.task_idx == 10
-            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+            assert trainer.use_ddp
+            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert isinstance(trainer.accelerator_backend, GPUAccelerator)
+            assert isinstance(trainer.training_type_plugin, DDPPlugin)
+            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
+            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
             raise SystemExit()

     model = BoringModel()
@@ -145,11 +118,13 @@ def on_fit_start(self,
trainer, pl_module): def test_accelerator_choice_ddp2_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type == DistributedType.DDP2 - assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp2 + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() @@ -175,11 +150,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -204,11 +180,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type == DistributedType.DDP2 - assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp2 + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -232,12 +209,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == 
trainer.accelerator_backend.task_idx - + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -263,9 +240,11 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) + assert trainer.use_ddp + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) raise SystemExit() model = BoringModel() @@ -299,9 +278,10 @@ def master_address(self): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, CustomCluster) + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, CustomCluster) raise SystemExit() model = BoringModel() @@ -327,28 +307,26 @@ def on_fit_start(self, trainer, pl_module): @mock.patch('torch.cuda.device_count', return_value=0) def test_custom_accelerator(tmpdir): class Accel(Accelerator): - def init_ddp_connection( - self, - global_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True) -> None: - pass + pass - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend, Accel) - raise SystemExit() + class Prec(PrecisionPlugin): + pass - model = BoringModel() + class TrainTypePlugin(SingleDevicePlugin): + pass + + accelerator = Accel( + training_type_plugin=TrainTypePlugin(device=torch.device("cpu")), + precision_plugin=Prec(), + ) trainer = Trainer( + accelerator=accelerator, fast_dev_run=True, - accelerator=Accel(), num_processes=2, - callbacks=[CB()] ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, Accel) + assert isinstance(trainer.training_type_plugin, TrainTypePlugin) + assert isinstance(trainer.precision_plugin, Prec) @mock.patch.dict(os.environ, { @@ -362,7 +340,8 @@ def on_fit_start(self, trainer, pl_module): def test_dist_backend_accelerator_mapping(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) raise SystemExit() model = BoringModel() diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 4949d53fc9a50..71747c21bf989 100644 --- a/tests/base/develop_pipelines.py 
+++ b/tests/base/develop_pipelines.py @@ -44,11 +44,6 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 for dataloader in test_loaders: run_prediction(pretrained_model, dataloader, min_acc=min_acc) - if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN): - # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() - def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True, min_acc: float = 0.25): @@ -84,10 +79,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, if with_hpc: if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers( - pretrained_model - ) + trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ + trainer.init_optimizers(pretrained_model) # test HPC saving trainer.checkpoint_connector.hpc_save(save_dir, logger) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index b12d0c2884106..b00ced2fefaf9 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -53,9 +53,9 @@ def test_trainer_callback_system(torch_save): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), + call.on_fit_start(trainer, model), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), - call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), call.on_sanity_check_start(trainer, model), @@ -108,11 +108,11 @@ def test_trainer_callback_system(torch_save): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), + call.on_fit_start(trainer, model), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'test'), - call.on_fit_start(trainer, model), - call.on_pretrain_routine_start(trainer, model), - call.on_pretrain_routine_end(trainer, model), + # call.on_pretrain_routine_start(trainer, model), + # call.on_pretrain_routine_end(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 4a0f0499b20e8..c28e1bdb8d658 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,21 +13,24 @@ # limitations under the License. 
import pickle from argparse import ArgumentParser +from unittest import mock +from unittest.mock import MagicMock, PropertyMock from typing import Any, Dict -from unittest.mock import MagicMock import pytest import torch from pytorch_lightning import LightningDataModule, Trainer -from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.model_helpers import is_overridden from tests.base import BoringDataModule, BoringModel from tests.base.develop_utils import reset_seed -def test_can_prepare_data(tmpdir): +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) +def test_can_prepare_data(local_rank, node_rank): dm = BoringDataModule() trainer = Trainer() @@ -37,33 +40,36 @@ def test_can_prepare_data(tmpdir): # prepare_data_per_node = True # local rank = 0 (True) trainer.prepare_data_per_node = True - trainer.local_rank = 0 + + local_rank.return_value = 0 + assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() # local rank = 1 (False) - trainer.local_rank = 1 + local_rank.return_value = 1 + assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) trainer.prepare_data_per_node = False - trainer.node_rank = 0 - trainer.local_rank = 0 + node_rank.return_value = 0 + local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() # global rank = 1 (False) - trainer.node_rank = 1 - trainer.local_rank = 0 + node_rank.return_value = 1 + local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() - trainer.node_rank = 0 - trainer.local_rank = 1 + node_rank.return_value = 0 + local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() # 2 dm # prepar per node = True # local rank = 0 (True) trainer.prepare_data_per_node = True - trainer.local_rank = 0 + local_rank.return_value = 0 # is_overridden prepare data = True # has been called @@ -392,7 +398,8 @@ def test_full_loop_dp(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_dm_transfer_batch_to_device(tmpdir): +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_dm_transfer_batch_to_device(get_module_mock): class CustomBatch: def __init__(self, data): self.samples = data[0] @@ -415,11 +422,10 @@ def transfer_batch_to_device(self, data, device): trainer = Trainer(gpus=1) # running .fit() would require us to implement custom data loaders, we mock the model reference instead - trainer.get_model = MagicMock(return_value=model) - - model.transfer_batch_to_device = dm.transfer_batch_to_device + get_module_mock.return_value = model + if is_overridden('transfer_batch_to_device', dm): + model.transfer_batch_to_device = dm.transfer_batch_to_device - trainer.accelerator_backend = GPUAccelerator(trainer) batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) expected = torch.device('cuda', 0) assert dm.hook_called diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 4d36027709900..f70fe53d9f44b 100644 --- a/tests/core/test_lightning_module.py +++ 
b/tests/core/test_lightning_module.py @@ -162,15 +162,15 @@ def configure_optimizers(self): optimizer_2 = Adam(self.layer.parameters(), lr=0.1) return [optimizer, optimizer_2] - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, - on_tpu=False, using_native_amp=False, using_lbfgs=False): - # warm up lr - if self.trainer.global_step < 500: - lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * 0.01 - - optimizer.step(closure=closure) + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * 0.01 + + optimizer.step(closure=optimizer_closure) model = TestModel() model.training_epoch_end = None diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index c9f6ea05ad2b8..94bfd6808ed79 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -20,6 +20,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -107,11 +108,17 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@mock.patch.dict(os.environ, { + "SLURM_NTASKS": "1", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" +}) def test_amp_gpu_ddp_slurm_managed(tmpdir): """Make sure DDP + AMP work.""" # simulate setting slurm flags tutils.set_random_master_port() - os.environ['SLURM_LOCALID'] = str(0) model = EvalModelTemplate() @@ -131,17 +138,17 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): callbacks=[checkpoint], logger=logger, ) - trainer.is_slurm_managing_tasks = True - trainer.fit(model) + result = trainer.fit(model) # correct result and ok accuracy assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete' # test root model address - assert trainer.slurm_connector.resolve_root_node_address('abc') == 'abc' - assert trainer.slurm_connector.resolve_root_node_address('abc[23]') == 'abc23' - assert trainer.slurm_connector.resolve_root_node_address('abc[23-24]') == 'abc23' - assert trainer.slurm_connector.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23' + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc') == 'abc' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23]') == 'abc23' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23' def test_cpu_model_with_amp(tmpdir): diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index dc8f3f1e4d50d..bcc3709d129cf 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -21,7 +21,6 @@ import tests.base.develop_pipelines as tpipes import 
tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -161,6 +160,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu): pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"), pytest.param([0], [0]), pytest.param([1, 3], [1, 3]), + pytest.param((1, 3), [1, 3]), pytest.param('0', [0]), pytest.param('3', [3]), pytest.param('1, 3', [1, 3]), @@ -180,7 +180,6 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): pytest.param([-1]), pytest.param([None]), pytest.param(['0']), - pytest.param((0, 1)), ]) def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus): with pytest.raises(MisconfigurationException): @@ -210,7 +209,6 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_single_gpu_batch_parse(): trainer = Trainer(gpus=1) - trainer.accelerator_backend = GPUAccelerator(trainer) # non-transferrable types primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] @@ -306,7 +304,6 @@ def to(self, *args, **kwargs): def test_non_blocking(): """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ trainer = Trainer() - trainer.accelerator_backend = GPUAccelerator(trainer) batch = torch.zeros(2, 3) with patch.object(batch, 'to', wraps=batch.to) as mocked: diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f45d3f423164d..227716d5e72c4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -13,13 +13,14 @@ # limitations under the License. 
import inspect import os +from unittest import mock from unittest.mock import MagicMock import pytest import torch +from unittest.mock import PropertyMock from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator from pytorch_lightning.trainer.states import TrainerState from tests.base import BoringModel, EvalModelTemplate, RandomDataset @@ -55,20 +56,19 @@ def test_training_epoch_end_metrics_collection(tmpdir): num_epochs = 3 class CurrentModel(EvalModelTemplate): - def training_step(self, *args, **kwargs): output = super().training_step(*args, **kwargs) - output['progress_bar'].update({'step_metric': torch.tensor(-1)}) - output['progress_bar'].update({'shared_metric': 100}) + output["progress_bar"].update({"step_metric": torch.tensor(-1)}) + output["progress_bar"].update({"shared_metric": 100}) return output def training_epoch_end(self, outputs): epoch = self.current_epoch # both scalar tensors and Python numbers are accepted return { - 'progress_bar': { - f'epoch_metric_{epoch}': torch.tensor(epoch), # add a new metric key every epoch - 'shared_metric': 111, + "progress_bar": { + f"epoch_metric_{epoch}": torch.tensor(epoch), # add a new metric key every epoch + "shared_metric": 111, } } @@ -83,19 +83,18 @@ def training_epoch_end(self, outputs): metrics = trainer.progress_bar_dict # metrics added in training step should be unchanged by epoch end method - assert metrics['step_metric'] == -1 + assert metrics["step_metric"] == -1 # a metric shared in both methods gets overwritten by epoch_end - assert metrics['shared_metric'] == 111 + assert metrics["shared_metric"] == 111 # metrics are kept after each epoch for i in range(num_epochs): - assert metrics[f'epoch_metric_{i}'] == i + assert metrics[f"epoch_metric_{i}"] == i @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_transfer_batch_hook(): - +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_transfer_batch_hook(model_getter_mock): class CustomBatch: - def __init__(self, data): self.samples = data[0] self.targets = data[1] @@ -117,11 +116,10 @@ def transfer_batch_to_device(self, data, device): batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer(gpus=1) - trainer.accelerator_backend = GPUAccelerator(trainer) # running .fit() would require us to implement custom data loaders, we mock the model reference instead - trainer.get_model = MagicMock(return_value=model) - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected = torch.device('cuda', 0) + model_getter_mock.return_value = model + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device("cuda:0")) + expected = torch.device("cuda", 0) assert model.hook_called assert batch_gpu.samples.device == batch_gpu.targets.device == expected @@ -402,8 +400,8 @@ def teardown(self, stage: str): expected = [ 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', + # 'on_pretrain_routine_start', + # 'on_pretrain_routine_end', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 85e91c4ae9d84..429ad108f1fc6 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -26,7 +26,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning 
import Trainer -from pytorch_lightning.accelerators.legacy.horovod_accelerator import HorovodAccelerator +from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE @@ -303,12 +303,12 @@ def _compute_batch(): accelerator='horovod', ) - accelerator_backend = trainer.accelerator_connector.select_accelerator() - assert isinstance(accelerator_backend, HorovodAccelerator) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + # TODO: test that we selected the correct training_type_plugin based on horovod flags metric = Accuracy(compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=accelerator_backend.gather_all_tensors, + dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, threshold=threshold) for i in range(hvd.rank(), num_batches, hvd.size()): diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 5e977eed765d0..20e9473b3a910 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader import tests.base.develop_pipelines as tpipes -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.trainer.states import TrainerState @@ -250,9 +250,9 @@ def test_broadcast_on_tpu(): """ Checks if an object from the master process is broadcasted to other processes correctly""" def test_broadcast(rank): trainer = Trainer(tpu_cores=8) - backend = TPUAccelerator(trainer) + assert isinstance(trainer.accelerator_backend, TPUAccelerator) obj = ("ver_0.5", "logger_name", rank) - result = backend.broadcast(obj) + result = trainer.accelerator_backend.broadcast(obj) assert result == ("ver_0.5", "logger_name", 0) xmp.spawn(test_broadcast, nprocs=8, start_method='fork') diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py new file mode 100644 index 0000000000000..bc4a21db554af --- /dev/null +++ b/tests/plugins/test_sharded_plugin.py @@ -0,0 +1,299 @@ +import os +import platform + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin, \ + ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base.boring_model import BoringModel + + +@pytest.mark.parametrize( + ["accelerator"], + [("ddp_sharded",), ("ddp_sharded_spawn",)] +) +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_sharded_ddp_choice(tmpdir, accelerator): + """ + Test to ensure that plugin is correctly chosen + """ + + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + if accelerator == 'ddp_sharded': + assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPShardedPlugin) + elif accelerator == 'ddp_sharded_spawn': + assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPSpawnShardedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator=accelerator, + callbacks=[CB()], 
+ ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_invalid_apex_sharded(tmpdir): + """ + Test to ensure that we raise an error when we try to use apex and sharded + """ + + model = BoringModel() + with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_sharded_spawn', + precision=16, + amp_backend='apex', + ) + + trainer.fit(model) + + +@pytest.mark.parametrize( + ["accelerator"], + [("ddp_sharded",), ("ddp_sharded_spawn",)] +) +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") +def test_ddp_choice_sharded_amp(tmpdir, accelerator): + """ + Test to ensure that plugin native amp plugin is correctly chosen when using sharded + """ + + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.precision_plugin, ShardedNativeMixedPrecisionPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=1, + precision=16, + accelerator=accelerator, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param, shard_param) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs + """ + model = BoringModel() + trainer = Trainer( + gpus=2, + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param, shard_param) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_finetune(tmpdir): + """ + Test to ensure that we can 
save and restart training (simulate fine-tuning) + """ + model = BoringModel() + trainer = Trainer( + gpus=2, + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + ) + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + trainer = Trainer( + fast_dev_run=True, + ) + trainer.fit(saved_model) + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): + """ + Test to ensure that resuming from checkpoint works + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + + +@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") +@pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): + """ + Test to ensure that resuming from checkpoint works when downsizing number of GPUS + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + gpus=2, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + gpus=1, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): + """ + Test to ensure that resuming from checkpoint works when going from GPUs- > CPU + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + gpus=1, + fast_dev_run=True + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_test(tmpdir): + """ + Test to ensure we can use test without fit + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + 
fast_dev_run=True, + ) + + trainer.test(model) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_test_multigpu(tmpdir): + """ + Test to ensure we can use test without fit + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + gpus=2, + fast_dev_run=True, + ) + + trainer.test(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a93a722bba597..b3105e97e18c1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -129,7 +129,7 @@ def test_multiple_val_dataloader(tmpdir): # make sure predictions are good for each val set for dataloader in trainer.val_dataloaders: - tpipes.run_prediction(trainer.model, dataloader) + tpipes.run_prediction(trained_model=model, dataloader=dataloader) @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])