From dd84937158f093a008da5cd1a59aef3d7e1f1f1e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 30 Dec 2020 00:54:18 +0100 Subject: [PATCH 01/16] warnings --- pytorch_lightning/trainer/deprecated_api.py | 28 ++++++++++----------- tests/deprecated_api/test_remove_1-4.py | 24 ++++++++++++------ 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 2c8377d2936c9..461bce639fe85 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -23,7 +23,7 @@ class DeprecatedDistDeviceAttributes: @property def on_cpu(self) -> bool: - # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.CPU @on_cpu.setter @@ -34,7 +34,7 @@ def on_cpu(self, val: bool) -> None: @property def on_tpu(self) -> bool: - # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.TPU @on_tpu.setter @@ -45,7 +45,7 @@ def on_tpu(self, val: bool) -> None: @property def use_tpu(self) -> bool: - # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.on_tpu @use_tpu.setter @@ -55,7 +55,7 @@ def use_tpu(self, val: bool) -> None: @property def on_gpu(self) -> bool: - # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.GPU @on_gpu.setter @@ -66,7 +66,7 @@ def on_gpu(self, val: bool) -> None: @property def use_dp(self) -> bool: - # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type == DistributedType.DP @use_dp.setter @@ -77,8 +77,8 @@ def use_dp(self, val: bool) -> None: @property def use_ddp(self) -> bool: - # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) - return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + return self._distrib_type == DistributedType.DDP @use_ddp.setter def use_ddp(self, val: bool) -> None: @@ -88,7 +88,7 @@ def use_ddp(self, val: bool) -> None: @property def use_ddp2(self) -> bool: - # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type == DistributedType.DDP2 @use_ddp2.setter @@ -99,9 +99,9 @@ def use_ddp2(self, val: bool) -> None: @property def use_horovod(self) -> bool: - # rank_zero_warn( - # "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning - # ) + 
rank_zero_warn( + "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) return self._distrib_type == DistributedType.HOROVOD @use_horovod.setter @@ -114,9 +114,9 @@ def use_horovod(self, val: bool) -> None: @property def use_single_gpu(self) -> bool: - # rank_zero_warn( - # "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning, - # ) + rank_zero_warn( + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) # todo, limiting to exclude DDP2 is not clear but it comes from connectors... return (self._device_type and self._device_type == DeviceType.GPU and self.num_gpus == 1 diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index db514cd5dde46..03973e040b150 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -45,35 +45,43 @@ def test_v1_4_0_deprecated_trainer_attributes(): with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_cpu = True - assert trainer.on_cpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_cpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_gpu = True - assert trainer.on_gpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_gpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_tpu = True - assert trainer.on_tpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_tpu trainer._device_type = None with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_tpu = True - assert trainer.use_tpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_tpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_dp = True - assert trainer.use_dp + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_dp with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_ddp = True - assert trainer.use_ddp + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_ddp2 = True - assert trainer.use_ddp2 + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp2 with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_horovod = True - assert trainer.use_horovod + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_horovod def test_v1_4_0_deprecated_metrics(): From 1f4a3e2344304401c94800631a5ddc5bf928c58e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 30 Dec 2020 01:11:25 +0100 Subject: [PATCH 02/16] . 
--- .../accelerators/accelerator_connector.py | 8 ++++---- .../accelerators/horovod_accelerator.py | 8 ++++---- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/gpu_stats_monitor.py | 2 +- pytorch_lightning/core/memory.py | 4 ++-- pytorch_lightning/core/optimizer.py | 4 ++-- pytorch_lightning/overrides/data_parallel.py | 2 +- pytorch_lightning/plugins/ddp_plugin.py | 3 ++- .../trainer/connectors/checkpoint_connector.py | 13 +++++++------ .../logger_connector/epoch_result_store.py | 3 ++- .../connectors/logger_connector/logger_connector.py | 6 +++--- .../trainer/connectors/model_connector.py | 12 +++++++----- .../trainer/connectors/slurm_connector.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 12 +++++++----- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- pytorch_lightning/tuner/lr_finder.py | 4 ++-- tests/base/model_test_epoch_ends.py | 10 ++++++---- 17 files changed, 55 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 7417f889dd808..1abd9b07be851 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -263,7 +263,7 @@ def select_accelerator(self): ddp_plugin=self.trainer.plugin_connector.ddp_plugin ) - elif self.trainer.use_dp: + elif self.trainer._distrib_type == DistributedType.DP: accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env) elif self.trainer.use_horovod: @@ -272,7 +272,7 @@ def select_accelerator(self): elif self.trainer.use_single_gpu: accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env) - elif self.trainer.use_tpu: + elif self.trainer._device_type == DeviceType.TPU: accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env) elif self.trainer.distributed_backend is None: @@ -353,7 +353,7 @@ def set_distributed_mode(self): 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' ) - rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}') + rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}') num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0 rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') @@ -366,7 +366,7 @@ def _set_horovod_backend(self): # Initialize Horovod to get rank / size info hvd.init() - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: # Horovod assigns one local GPU per process self.trainer.root_gpu = hvd.local_rank() diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index fec5e53492005..cc0297b4de017 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_only if _HOROVOD_AVAILABLE: @@ -46,7 +46,7 @@ def setup(self, model): # call setup after the ddp process has connected self.trainer.call_setup_hook(model) - if torch.cuda.is_available() and 
self.trainer.on_gpu: + if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU: # Horovod: pin GPU to local rank assert self.trainer.root_gpu == hvd.local_rank() torch.cuda.set_device(self.trainer.root_gpu) @@ -116,7 +116,7 @@ def train(self): return results def _step(self, model_step: Callable, args): - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: args[0] = self.batch_to_device(args[0], hvd.local_rank()) if self.trainer.amp_backend == AMPType.NATIVE: @@ -141,7 +141,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): optimizer.synchronize() def on_train_epoch_end(self, outputs): - hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1) + hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1) def barrier(self, name: Optional[str] = None): hvd.join() diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ec44a1eeb416b..e4084fcd769db 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -24,7 +24,7 @@ import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 1403d0bdf2e31..09c73d8e9b352 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -104,7 +104,7 @@ def on_train_start(self, trainer, *args, **kwargs): 'Cannot use GPUStatsMonitor callback with Trainer that has no logger.' ) - if not trainer.on_gpu: + if trainer._device_type != DeviceType.GPU: raise MisconfigurationException( 'You are using GPUStatsMonitor but are not running on GPU' f' since gpus attribute in Trainer is set to {trainer.gpus}.' diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index faafc0a0f0584..44c06dfe0f58d 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch.utils.hooks import RemovableHandle -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import AMPType, DeviceType PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] UNKNOWN_SIZE = "?" 
@@ -229,7 +229,7 @@ def _forward_example_input(self) -> None: input_ = model.example_input_array input_ = model.transfer_batch_to_device(input_, model.device) - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu: + if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 4e5ab14d91980..acba35d9ae0ac 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -17,7 +17,7 @@ from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: @@ -125,7 +125,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n optimizer = self._optimizer model = trainer.get_model() - if trainer.on_tpu: + if trainer._device_type == DeviceType.TPU: with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 1943a83644e29..1552611f57e16 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None): if output is None: warn_missing_output(fx_called) - if output is not None and (module.use_dp or module.use_ddp2): + if output is not None and module.distrib_type in ("dp", "ddp2"): auto_squeeze_dim_zeros(output) # --------------- diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 16db194e97c97..f32e35c5e085e 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -22,6 +22,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.plugins.plugin import LightningPlugin +from pytorch_lightning.utilities import DeviceType class DDPPlugin(LightningPlugin): @@ -95,7 +96,7 @@ def init_ddp_connection( os.environ["MASTER_ADDR"] = str(cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(cluster_environment.world_size()) - torch_backend = "nccl" if trainer.on_gpu else "gloo" + torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo" if not torch_distrib.is_initialized(): log.info( diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d46e0e4cf3503..c8d38591e70a6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,8 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn, DeviceType) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem 
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -50,7 +51,7 @@ def restore_weights(self, model: LightningModule) -> None: 3. don't restore """ # clear cache before restore - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: torch.cuda.empty_cache() # 1. Attempt to restore states from HPC checkpoint @@ -58,18 +59,18 @@ def restore_weights(self, model: LightningModule) -> None: max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' - self.hpc_load(checkpoint_path, self.trainer.on_gpu) + self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU) rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. Attempt to restore states from `resume_from_checkpoint` file elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: - self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) + self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU) # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') # clear cache after restore - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: torch.cuda.empty_cache() def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: @@ -291,7 +292,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump amp scaling if (self.trainer.amp_backend == AMPType.NATIVE - and not self.trainer.use_tpu + and self.trainer._device_type != DeviceType.TPU and self.trainer.scaler is not None): checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() elif self.trainer.amp_backend == AMPType.APEX: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2796a61ee5c83..cb3b0b3c235e6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -18,6 +18,7 @@ import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities import DistributedType class LoggerStages(str, Enum): @@ -343,7 +344,7 @@ def cache_result(self) -> None: hook_result.detach() if self.trainer.move_metrics_to_cpu: hook_result.cpu() - elif self.trainer.use_dp: + elif self.trainer._distrib_type == DistributedType.DP: hook_result.to(torch.device("cuda", self.trainer.root_gpu)) self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 73e9223fb7d0f..4d6e560fefdac 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from copy import deepcopy import os +from copy import deepcopy from pprint import pprint from typing import Any, Iterable, Union, Dict @@ -24,7 +24,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -from pytorch_lightning.utilities import flatten_dict +from pytorch_lightning.utilities import flatten_dict, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -219,7 +219,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics= and global_step for the rest. """ # add gpu memory - if self.trainer.on_gpu and self.trainer.log_gpu_memory: + if self.trainer._device_type == DeviceType.GPU and self.trainer.log_gpu_memory: mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory) metrics.update(mem_map) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index c5a8c48357b44..9a086aff1bd78 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -32,13 +32,15 @@ def copy_trainer_model_properties(self, model): for m in [model, ref_model]: m.trainer = self.trainer m.logger = self.trainer.logger - m.use_dp = self.trainer.use_dp - m.use_ddp2 = self.trainer.use_ddp2 - m.use_ddp = self.trainer.use_ddp + m.device_type = str(self.trainer._device_type) + m.distrib_type = str(self.trainer._distrib_type) + # m.use_dp = self.trainer.use_dp + # m.use_ddp2 = self.trainer.use_ddp2 + # m.use_ddp = self.trainer.use_ddp m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing - m.use_single_gpu = self.trainer.use_single_gpu - m.use_tpu = self.trainer.use_tpu + # m.use_single_gpu = self.trainer.use_single_gpu + # m.use_tpu = self.trainer.use_tpu m.tpu_local_core_rank = self.trainer.tpu_local_core_rank m.tpu_global_core_rank = self.trainer.tpu_global_core_rank m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index e17235779f22b..5248aabbfa913 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -3,6 +3,7 @@ import signal from subprocess import call from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info import torch.distributed as torch_distrib import torch @@ -145,7 +146,7 @@ def connect_ddp(self, global_rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = root_node log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - torch_backend = "nccl" if self.trainer.on_gpu else "gloo" + torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo" if not torch.distributed.is_initialized(): log.info( diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 64eb224a428f1..714a4592d984c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.step_result import 
EvalResult, Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum -from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing +from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -102,7 +102,7 @@ def should_skip_training(self): def on_train_start(self): # clear cache before training - if self.trainer.on_gpu and self.trainer.root_gpu is not None: + if self.trainer._device_type == DeviceType.GPU and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"): @@ -152,7 +152,9 @@ def setup_training(self, model: LightningModule): self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. Must be done here instead of __init__ to allow ddp to work - if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: + if (self.trainer.amp_backend == AMPType.NATIVE + and self.trainer.precision == 16 + and self.trainer._device_type != DeviceType.TPU): self.trainer.scaler = self.trainer.precision_connector.backend.scaler # log hyper-parameters @@ -219,7 +221,7 @@ def on_train_end(self): self.trainer.accelerator_backend.on_train_end() # clear mem - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() @@ -508,7 +510,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ optimizer, opt_idx, train_step_and_backward_closure, - on_tpu=self.trainer.use_tpu and _TPU_AVAILABLE, + on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, ) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 52662f6172d8d..b20772c867b56 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -17,7 +17,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda from pytorch_lightning.loggers.base import DummyLogger @@ -115,7 +115,7 @@ def scale_batch_size(trainer, # Restore initial state of model if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index e0fab12eec9d3..d4ee79f466b5b 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -29,7 +29,7 @@ from 
pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, DeviceType from pytorch_lightning.utilities.cloud_io import get_filesystem # check if ipywidgets is installed before importing tqdm.auto @@ -192,7 +192,7 @@ def lr_find( # Reset model state if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/tests/base/model_test_epoch_ends.py b/tests/base/model_test_epoch_ends.py index 164a7d3671923..90084298b3187 100644 --- a/tests/base/model_test_epoch_ends.py +++ b/tests/base/model_test_epoch_ends.py @@ -15,6 +15,8 @@ import torch +from pytorch_lightning.utilities import DistributedType + class TestEpochEndVariations(ABC): @@ -33,13 +35,13 @@ def test_epoch_end(self, outputs): test_loss = self.get_output_metric(output, 'test_loss') # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = self.get_output_metric(output, 'test_acc') - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc @@ -68,13 +70,13 @@ def test_epoch_end__multiple_dataloaders(self, outputs): test_loss = output['test_loss'] # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = output['test_acc'] - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc From ae349a38e8cc546036f7fa7be97f305c7e820fa4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 30 Dec 2020 01:19:45 +0100 Subject: [PATCH 03/16] . 
--- .../accelerators/accelerator_connector.py | 17 +++++++-------- .../trainer/connectors/slurm_connector.py | 4 ++-- tests/backends/test_accelerator_connector.py | 21 ++++++++++--------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 1abd9b07be851..9680fa8f6610a 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -185,14 +185,13 @@ def select_accelerator(self): # ---------------------------------- # choose an accelerator for the user # ---------------------------------- - use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks + use_slurm_ddp = self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and self.trainer.is_slurm_managing_tasks # torchelastic or general non_slurm ddp te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) - use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed + use_torchelastic_ddp = self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed - use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn" - use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu" + use_ddp_cpu_spawn = self.trainer._distrib_type == DistributedType.DDP_SPAWN and self.trainer._device_type == DeviceType.CPU use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks @@ -205,7 +204,7 @@ def select_accelerator(self): cluster_env = self._select_environment() # choose the appropriate accelerator backend - if self.trainer.use_ddp2: + if self.trainer._distrib_type == DistributedType.DDP2: accelerator_backend = accelerators.DDP2Accelerator( self.trainer, cluster_env, @@ -240,7 +239,7 @@ def select_accelerator(self): self.trainer.plugin_connector.ddp_plugin ) - elif use_ddp_spawn: + elif self.trainer._distrib_type == DistributedType.DDP_SPAWN: accelerator_backend = accelerators.DDPSpawnAccelerator( self.trainer, nprocs=self.trainer.num_processes, @@ -266,10 +265,10 @@ def select_accelerator(self): elif self.trainer._distrib_type == DistributedType.DP: accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env) - elif self.trainer.use_horovod: + elif self.trainer._distrib_type == DistributedType.HOROVOD: accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env) - elif self.trainer.use_single_gpu: + elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1: accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env) elif self.trainer._device_type == DeviceType.TPU: @@ -347,7 +346,7 @@ def set_distributed_mode(self): self._set_horovod_backend() # throw error to force user ddp or ddp2 choice - if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP): + if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): raise MisconfigurationException( 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. 
' 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index 5248aabbfa913..22a8ee229ad4a 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -3,7 +3,7 @@ import signal from subprocess import call from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.distributed import rank_zero_info import torch.distributed as torch_distrib import torch @@ -23,7 +23,7 @@ def configure_slurm_ddp(self, num_gpu_nodes): # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes # otherwise we launch the required number of processes - if self.trainer.use_ddp or self.trainer.use_ddp2: + if self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): self.trainer.num_requested_gpus = self.trainer.num_gpus * num_gpu_nodes self.trainer.num_slurm_tasks = 0 try: diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index dc8bf338d3eb3..0ba08f72fbd30 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -21,6 +21,7 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.cluster_environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.utilities import DistributedType from tests.base.boring_model import BoringModel @@ -41,7 +42,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSpawnAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -63,7 +64,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -85,7 +86,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_spawn(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPSpawnAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -113,7 +114,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, 
accelerators.DDPHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -144,7 +145,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp2 + assert trainer._distrib_type == DistributedType.DDP2 assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -174,7 +175,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -203,7 +204,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp2 + assert trainer._distrib_type == DistributedType.DDP2 assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -231,7 +232,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -262,7 +263,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) raise SystemExit() @@ -298,7 +299,7 @@ def master_address(self): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, CustomCluster) raise SystemExit() From 553854f8a0d33f0f14d70dc7add2c465029bb696 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 30 Dec 2020 11:57:10 +0100 Subject: [PATCH 04/16] flake8 --- .../accelerators/accelerator_connector.py | 13 +++++++++---- pytorch_lightning/callbacks/gpu_stats_monitor.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 9680fa8f6610a..5de8c54986e91 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py 
+++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -185,13 +185,16 @@ def select_accelerator(self): # ---------------------------------- # choose an accelerator for the user # ---------------------------------- - use_slurm_ddp = self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and self.trainer.is_slurm_managing_tasks + use_slurm_ddp = (self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer.is_slurm_managing_tasks) # torchelastic or general non_slurm ddp te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) - use_torchelastic_ddp = self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed + use_torchelastic_ddp = (self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and te_flags_passed) - use_ddp_cpu_spawn = self.trainer._distrib_type == DistributedType.DDP_SPAWN and self.trainer._device_type == DeviceType.CPU + use_ddp_cpu_spawn = (self.trainer._distrib_type == DistributedType.DDP_SPAWN + and self.trainer._device_type == DeviceType.CPU) use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks @@ -352,7 +355,9 @@ def set_distributed_mode(self): 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' ) - rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}') + rank_zero_info( + f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}' + ) num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0 rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 09c73d8e9b352..3b8ab457c5f12 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -27,7 +27,7 @@ from typing import Dict, List, Tuple from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_only, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict From 4e7c49f93dffc1c0eb79a8d30df51b798303dacc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 21:30:27 +0100 Subject: [PATCH 05/16] . 
--- .../accelerators/accelerator_connector.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 5de8c54986e91..cc3f704d46a90 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -185,16 +185,21 @@ def select_accelerator(self): # ---------------------------------- # choose an accelerator for the user # ---------------------------------- - use_slurm_ddp = (self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - and self.trainer.is_slurm_managing_tasks) + use_slurm_ddp = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer.is_slurm_managing_tasks + ) # torchelastic or general non_slurm ddp te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) - use_torchelastic_ddp = (self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - and te_flags_passed) + use_torchelastic_ddp = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed + ) - use_ddp_cpu_spawn = (self.trainer._distrib_type == DistributedType.DDP_SPAWN - and self.trainer._device_type == DeviceType.CPU) + use_ddp_cpu_spawn = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer._device_type == DeviceType.CPU + ) use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks From 1893a03190fed2e7a9650f98d7e392478f836ed8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:05:07 +0100 Subject: [PATCH 06/16] . --- pytorch_lightning/accelerators/accelerator_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index cc3f704d46a90..fbf2a3668d9ff 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -211,6 +211,7 @@ def select_accelerator(self): cluster_env = self._select_environment() + # TODO: clean-up this branching as most just select class and uses the very same arguments # choose the appropriate accelerator backend if self.trainer._distrib_type == DistributedType.DDP2: accelerator_backend = accelerators.DDP2Accelerator( From 9c4be4b89aeb42526854bf0a03de80345cdee481 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:16:07 +0100 Subject: [PATCH 07/16] . 
--- .../accelerators/accelerator_connector.py | 13 +++++++------ pytorch_lightning/callbacks/early_stopping.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index fbf2a3668d9ff..f04e3704550ff 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -186,19 +186,19 @@ def select_accelerator(self): # choose an accelerator for the user # ---------------------------------- use_slurm_ddp = ( - self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - and self.trainer.is_slurm_managing_tasks + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer.is_slurm_managing_tasks ) # torchelastic or general non_slurm ddp te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) use_torchelastic_ddp = ( - self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed ) use_ddp_cpu_spawn = ( - self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - and self.trainer._device_type == DeviceType.CPU + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer._device_type == DeviceType.CPU ) use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() @@ -355,7 +355,8 @@ def set_distributed_mode(self): self._set_horovod_backend() # throw error to force user ddp or ddp2 choice - if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): + _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp): raise MisconfigurationException( 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. 
' 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index e4084fcd769db..ec44a1eeb416b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -24,7 +24,7 @@ import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException From 88f4f01a7f08797c0920c8c59d3cce86f852af28 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:30:57 +0100 Subject: [PATCH 08/16] use_tpu --- .../connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 13 +++++++------ pytorch_lightning/trainer/trainer.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 4d6e560fefdac..db41011e57d6c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -81,7 +81,7 @@ def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) model_ref = self.trainer.get_model() metrics_holder.convert( - self.trainer.use_tpu, + self.trainer._device_type == DeviceType.TPU, model_ref.device if model_ref is not None else model_ref ) return metrics_holder.metrics diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index eeb5b2f0cd4e5..430c6a6cde66e 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -27,7 +27,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType from pytorch_lightning.utilities.argparse import ( from_argparse_args, parse_argparser, parse_env_variables, add_argparse_args ) @@ -48,9 +48,8 @@ class TrainerProperties(ABC): _state: TrainerState global_rank: int fast_dev_run: Union[int, bool] - use_dp: bool - use_ddp: bool - use_ddp2: bool + _device_type: DeviceType + _distrib_type: DistributedType model: LightningModule data_parallel_device_ids: Optional[List[int]] _progress_bar_callback: ProgressBarBase @@ -62,6 +61,8 @@ class TrainerProperties(ABC): model_connector: ModelConnector checkpoint_connector: CheckpointConnector callbacks: List[Callback] + num_nodes: int + num_processes: int @property def log_dir(self): @@ -275,14 +276,14 @@ def __setstate__(self, d): def require_distributed_sampler(self): if self.accelerator_backend is not None: return self.accelerator_backend.require_distributed_sampler - return self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + return self.use_ddp or self.use_ddp2 or self.use_horovod or self._device_type == DeviceType.TPU @property def distributed_sampler_kwargs(self): if self.accelerator_backend is not None: return 
self.accelerator_backend.distributed_sampler_kwargs - if self.use_tpu: + if self._device_type == DeviceType.TPU: kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) elif self.use_horovod: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b923ae9adce0c..ab7b411311ecc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -787,7 +787,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): f'specify a path for a checkpoint .test(ckpt_path=PATH)' ) return {} - if self.accelerator_backend is not None and not self.use_tpu: + if self.accelerator_backend is not None and not self._device_type == DeviceType.TPU: self.accelerator_backend.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) From d2be8af81b88d3083496526f0fe0f6993a3b9108 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:36:31 +0100 Subject: [PATCH 09/16] use_dp --- pytorch_lightning/trainer/logging.py | 14 +++++++------- pytorch_lightning/trainer/properties.py | 4 +++- tests/base/deterministic_model.py | 5 +++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 976762a5b4711..1dd567de89c68 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -19,6 +19,7 @@ import torch from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -28,13 +29,12 @@ class TrainerLoggingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class current_epoch: int - on_gpu: bool + _device_type: DeviceType + _distrib_type: DistributedType log_gpu_memory: ... 
logger: Union[LightningLoggerBase, bool] global_step: int global_rank: int - use_dp: bool - use_ddp2: bool default_root_dir: str slurm_job_id: int num_gpus: int @@ -96,7 +96,7 @@ def process_dict_result(self, output, train=False): if k not in ['progress_bar', 'log', 'hiddens']: callback_metrics[k] = v - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) @@ -107,7 +107,7 @@ def process_dict_result(self, output, train=False): progress_output = output['progress_bar'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) @@ -124,7 +124,7 @@ def process_dict_result(self, output, train=False): log_output = output['log'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) @@ -152,7 +152,7 @@ def process_dict_result(self, output, train=False): ) from exp # when using dp need to reduce the loss - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): loss = self.reduce_distributed_output(loss, self.num_gpus) # --------------- diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 430c6a6cde66e..c33dcff37af8a 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -177,7 +177,9 @@ def num_gpus(self) -> int: @property def data_parallel(self) -> bool: - return self.use_dp or self.use_ddp or self.use_ddp2 + return self._distrib_type in ( + DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2 + ) @property def progress_bar_callback(self): diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index 04e8e57b2e569..cc3bbc8c54be3 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -16,6 +16,7 @@ from torch.utils.data import Dataset, DataLoader from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import DistributedType class DeterministicModel(LightningModule): @@ -99,7 +100,7 @@ def training_epoch_end_scalar(self, outputs): """ self.training_epoch_end_called = True - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): pass else: # only saw 4 batches @@ -160,7 +161,7 @@ def training_step_end_dict(self, output): def training_epoch_end_dict(self, outputs): self.training_epoch_end_called = True - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): pass else: # only saw 4 batches From 41a8ad43403d4f3fb046914dc892afd9c8185366 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:40:30 +0100 Subject: [PATCH 10/16] . 
--- pytorch_lightning/trainer/deprecated_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 461bce639fe85..aaa1ba47adf73 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -78,7 +78,7 @@ def use_dp(self, val: bool) -> None: @property def use_ddp(self) -> bool: rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) - return self._distrib_type == DistributedType.DDP + return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) @use_ddp.setter def use_ddp(self, val: bool) -> None: From 3c1d3a10865a1a79879c01dfb746558feef8d9c6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 10 Jan 2021 22:46:42 +0100 Subject: [PATCH 11/16] use_ddp --- pytorch_lightning/core/lightning.py | 3 +++ pytorch_lightning/trainer/data_loading.py | 1 - pytorch_lightning/trainer/properties.py | 4 +++- tests/base/develop_pipelines.py | 5 +++-- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c2ec67819912e..7a9153a5079dd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -85,6 +85,9 @@ def __init__(self, *args, **kwargs): #: Pointer to the logger object self.logger = None + self._distrib_type = None + self._device_type = None + #: True if using dp self.use_dp = False diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 3db83c415aded..b4f4626dba310 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -38,7 +38,6 @@ class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class global_rank: int - use_ddp: bool use_ddp2: bool use_horovod: bool shown_warnings: ... 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index c33dcff37af8a..e3da7430138b4 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -278,7 +278,9 @@ def __setstate__(self, d):
     def require_distributed_sampler(self):
         if self.accelerator_backend is not None:
             return self.accelerator_backend.require_distributed_sampler
-        return self.use_ddp or self.use_ddp2 or self.use_horovod or self._device_type == DeviceType.TPU
+        return self._distrib_type in (
+            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
+        ) or self._device_type == DeviceType.TPU

     @property
     def distributed_sampler_kwargs(self):
diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py
index c4197741a0791..4949d53fc9a50 100644
--- a/tests/base/develop_pipelines.py
+++ b/tests/base/develop_pipelines.py
@@ -15,6 +15,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.utilities import DistributedType
 from tests.base import BoringModel
 from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed


@@ -43,7 +44,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50
     for dataloader in test_loaders:
         run_prediction(pretrained_model, dataloader, min_acc=min_acc)

-    if trainer.use_ddp:
+    if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN):
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
         trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
@@ -81,7 +82,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
         run_prediction(pretrained_model, dataloader, min_acc=min_acc)

     if with_hpc:
-        if trainer.use_ddp or trainer.use_ddp2:
+        if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2):
             # on hpc this would work fine... but need to hack it for the purpose of the test
             trainer.model = pretrained_model
             trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(

From 57057647a419b55c5ef3546a02052a479e37a1ec Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sun, 10 Jan 2021 22:50:49 +0100
Subject: [PATCH 12/16] .
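The diff below removes the per-module `use_dp`/`use_ddp`/`use_ddp2`/`use_tpu` flags from
`LightningModule.__init__`; the model connector change further down copies the private
`_distrib_type`/`_device_type` attributes onto the module instead. A hedged sketch of how a
module can branch on the new attribute, mirroring the pattern already used in
`tests/base/deterministic_model.py` above; the `MyModule` class is invented for the example
and assumes the trainer has populated `_distrib_type`:

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.utilities import DistributedType

    class MyModule(LightningModule):
        def training_epoch_end(self, outputs):
            # branch on the copied attribute instead of the removed boolean flags
            if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
                ...  # handle the DP/DDP2 case separately, as the test model above does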
---
 pytorch_lightning/core/lightning.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 7a9153a5079dd..dd5691d6e4553 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -88,18 +88,6 @@ def __init__(self, *args, **kwargs):
         self._distrib_type = None
         self._device_type = None

-        #: True if using dp
-        self.use_dp = False
-
-        #: True if using ddp
-        self.use_ddp = False
-
-        #: True if using ddp2
-        self.use_ddp2 = False
-
-        # True if on tpu
-        self.use_tpu = False
-
         #: True if using amp
         self.use_amp = False


From c7a671939e2960b45d1b0406b212d74d57054212 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sun, 10 Jan 2021 22:51:58 +0100
Subject: [PATCH 13/16] use_horovod

---
 pytorch_lightning/trainer/properties.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index e3da7430138b4..786e775668ca2 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -290,7 +290,7 @@ def distributed_sampler_kwargs(self):
         if self._device_type == DeviceType.TPU:
             kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

-        elif self.use_horovod:
+        elif self._distrib_type == DistributedType.HOROVOD:
             kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())

         else:

From db6fbd65f39d8f6bed6015122d9050af4d2651d5 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sun, 10 Jan 2021 22:53:55 +0100
Subject: [PATCH 14/16] .

---
 pytorch_lightning/trainer/connectors/model_connector.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index 9a086aff1bd78..a3759d1075ee5 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -32,15 +32,10 @@ def copy_trainer_model_properties(self, model):
         for m in [model, ref_model]:
             m.trainer = self.trainer
             m.logger = self.trainer.logger
-            m.device_type = str(self.trainer._device_type)
-            m.distrib_type = str(self.trainer._distrib_type)
-            # m.use_dp = self.trainer.use_dp
-            # m.use_ddp2 = self.trainer.use_ddp2
-            # m.use_ddp = self.trainer.use_ddp
+            m._device_type = str(self.trainer._device_type)
+            m._distrib_type = str(self.trainer._distrib_type)
             m.use_amp = self.trainer.amp_backend is not None
             m.testing = self.trainer.testing
-            # m.use_single_gpu = self.trainer.use_single_gpu
-            # m.use_tpu = self.trainer.use_tpu
             m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
             m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
             m.precision = self.trainer.precision

From ce5a4ab2221a93f62a2b3a4e2b57c750702f182f Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sun, 10 Jan 2021 22:57:21 +0100
Subject: [PATCH 15/16] .

---
 pytorch_lightning/trainer/data_loading.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index b4f4626dba310..fa3bd2092945a 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -38,11 +38,8 @@ class TrainerDataLoadingMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     global_rank: int
-    use_ddp2: bool
-    use_horovod: bool
     shown_warnings: ...
     val_check_interval: float
-    use_tpu: bool
     tpu_local_core_rank: int
     train_dataloader: DataLoader
     num_training_batches: Union[int, float]

From a42e8c0b5af338f3ae8bbd9b09dea8aefaae9f2e Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 11 Jan 2021 20:15:00 +0100
Subject: [PATCH 16/16] .

---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 pytorch_lightning/utilities/enums.py         | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 1552611f57e16..f6f045134f2f9 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None):
             if output is None:
                 warn_missing_output(fx_called)

-            if output is not None and module.distrib_type in ("dp", "ddp2"):
+            if output is not None and module._distrib_type in ('dp', 'ddp2'):
                 auto_squeeze_dim_zeros(output)

             # ---------------
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index bbcb83b6ee15a..5ff8f81fe1f0b 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -50,7 +50,7 @@ class DistributedType(LightningEnum):
     >>> DistributedType.DDP == 'ddp'
     True
     >>> # which is case invariant
-    >>> DistributedType.DDP2 == 'DDP2'
+    >>> DistributedType.DDP2 in ('ddp2', )
     True
     """
     DP = 'dp'
@@ -69,7 +69,7 @@ class DeviceType(LightningEnum):
     >>> DeviceType.GPU == 'GPU'
     True
     >>> # which is case invariant
-    >>> DeviceType.TPU == 'tpu'
+    >>> DeviceType.TPU in ('tpu', 'CPU')
     True
     """
     CPU = 'CPU'
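For reference, a small standalone sketch of the comparison semantics the doctests above
exercise; it is illustrative only and assumes `LightningEnum` keeps the case-insensitive
string comparison noted in the docstrings:

    from pytorch_lightning.utilities.enums import DeviceType, DistributedType

    assert DistributedType.DDP == 'ddp'        # string comparison is case-invariant
    assert DistributedType.DDP2 in ('ddp2', )  # `in` falls back to `==` on each element
    assert DeviceType.GPU == 'GPU'
    assert DeviceType.TPU in ('tpu', 'CPU')    # True thanks to the 'tpu' element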