diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e26dc8b476ab2..1596c898f88c3 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -374,3 +374,6 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
 
     def on_save(self, checkpoint):
         return checkpoint
+
+    def barrier(self, name: Optional[str] = None) -> None:
+        self.training_type_plugin.barrier(name=name)
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 1fa95ef4c13b5..5c1d3eb6ebed2 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+from typing import Optional, Sequence
 
 import torch
 
@@ -26,15 +27,21 @@
     DataParallelPlugin,
     DDP2Plugin,
     DDPPlugin,
+    DDPShardedPlugin,
     DDPSpawnPlugin,
+    DDPSpawnShardedPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,
+    RPCPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
-    TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
+    TPUSpawnPlugin,
+    TrainingTypePlugin,
+    DDPShardedPlugin,
+    DDPSpawnShardedPlugin,
 )
 from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -74,6 +81,7 @@ def __init__(
         amp_type,
         amp_level,
         cluster_environment,
+        plugins,
     ):
         # initialization
         self._device_type = DeviceType.CPU
@@ -95,6 +103,11 @@ def __init__(
         self.cluster_environment = cluster_environment
         self.is_slurm_managing_tasks = False
 
+        self._precision_plugin: Optional[PrecisionPlugin] = None
+        self._training_type_plugin: Optional[TrainingTypePlugin] = None
+
+        self.handle_given_plugins(plugins)
+
         # init the default rank if exists
         # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
         # this way we only show it on rank 0
@@ -136,6 +149,56 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
+    def handle_given_plugins(self, plugins: Optional[Sequence]):
+        if plugins is None:
+            return
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        training_type = None
+        precision = None
+
+        for plug in plugins:
+            if isinstance(plug, TrainingTypePlugin):
+                if training_type is None:
+                    training_type = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 training type plugin'
+                    )
+            elif isinstance(plug, PrecisionPlugin):
+                if precision is None:
+                    precision = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 precision plugin'
+                    )
+            else:
+                raise MisconfigurationException(
+                    f'Found invalid type for plugin {plug}. '
+                    'Expected a precision or training type plugin.'
+                )
+
+        self._training_type_plugin = training_type
+        self._precision_plugin = precision
+
+    @property
+    def precision_plugin(self) -> PrecisionPlugin:
+        if self._precision_plugin is None:
+            self._precision_plugin = self.select_precision_plugin()
+
+        return self._precision_plugin
+
+    @property
+    def training_type_plugin(self) -> TrainingTypePlugin:
+        if self._training_type_plugin is None:
+            self._training_type_plugin = self.select_training_type_plugin()
+
+        return self._training_type_plugin
+
     @property
     def on_cpu(self):
         return self._device_type == DeviceType.CPU
@@ -205,6 +268,9 @@ def select_precision_plugin(self):
         if self.on_tpu:
             return TPUHalfPrecisionPlugin()
 
+        if isinstance(self.training_type_plugin, RPCPlugin):
+            raise MisconfigurationException('Mixed precision is currently not supported with RPC plugins.')
+
         if self.amp_type == "native":
             if not _NATIVE_AMP_AVAILABLE:
                 rank_zero_warn(
@@ -215,7 +281,7 @@ def select_precision_plugin(self):
                 self.amp_type = "apex"
             else:
                 log.info("Using native 16bit precision.")
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     return ShardedNativeMixedPrecisionPlugin()
                 self.amp_type = AMPType.NATIVE
                 return NativeMixedPrecisionPlugin()
@@ -227,7 +293,7 @@ def select_precision_plugin(self):
                     " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                 )
             else:
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     raise MisconfigurationException(
                         "Sharded Plugin is not supported with Apex AMP, "
                         "please using native AMP for 16-bit precision."
@@ -289,6 +355,12 @@ def select_training_type_plugin(self):
     def select_accelerator(self):
         if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
+            if self._precision_plugin is not None or self._training_type_plugin is not None:
+                # plugins also specified by user
+                rank_zero_warn(
+                    'Specified Precision and TrainingType Plugins will be ignored, '
+                    'since an Accelerator instance was provided'
+                )
             return self.distributed_backend
 
         if self.on_gpu:
@@ -299,8 +371,8 @@ def select_accelerator(self):
             acc_cls = CPUAccelerator
 
         return acc_cls(
-            precision_plugin=self.select_precision_plugin(),
-            training_type_plugin=self.select_training_type_plugin(),
+            precision_plugin=self.precision_plugin,
+            training_type_plugin=self.training_type_plugin,
         )
 
     def select_cluster_environment(self):
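As a sketch of the rule the new `handle_given_plugins` hunk above enforces (not part of the patch; the Stub* classes below are hypothetical stand-ins for the real `TrainingTypePlugin` / `PrecisionPlugin` base classes): at most one training-type plugin and one precision plugin may be passed, anything else raises, and whatever is not given is later filled in lazily by the `precision_plugin` / `training_type_plugin` properties.

from typing import Optional, Sequence, Tuple


class StubTrainingTypePlugin:  # stand-in for TrainingTypePlugin
    pass


class StubPrecisionPlugin:  # stand-in for PrecisionPlugin
    pass


def resolve_plugins(plugins: Optional[Sequence]) -> Tuple[Optional[object], Optional[object]]:
    """Accept at most one training-type and one precision plugin, mirroring handle_given_plugins."""
    training_type, precision = None, None
    for plug in plugins or []:
        if isinstance(plug, StubTrainingTypePlugin):
            if training_type is not None:
                raise ValueError("Found more than 1 training type plugin")
            training_type = plug
        elif isinstance(plug, StubPrecisionPlugin):
            if precision is not None:
                raise ValueError("Found more than 1 precision plugin")
            precision = plug
        else:
            raise ValueError(f"Found invalid type for plugin {plug}.")
    return training_type, precision


print(resolve_plugins([StubTrainingTypePlugin(), StubPrecisionPlugin()]))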
training.") return super().setup(trainer, model) + + def optimizer_step( + self, optimizer: torch.optim.Optimizer, current_epoch: int, batch_idx: int, opt_idx: int, + lambda_closure: Callable + ): + + self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) + + xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure}) + + self.precision_plugin.post_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.post_optimizer_step(optimizer, opt_idx) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d39e600820735..2de2684bd9bc0 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -196,6 +196,7 @@ def _run_early_stopping_check(self, trainer, pl_module): if self.monitor_op(current - self.min_delta, self.best_score): self.best_score = current self.wait_count = 0 + should_stop = False else: self.wait_count += 1 should_stop = self.wait_count >= self.patience diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 14ab52c3c6fba..4bb1e81885852 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -275,7 +275,6 @@ def log( raise MisconfigurationException( f"Logged key: {name} should not contain information about dataloader_idx.") - accelerator = self.trainer.accelerator_backend training_type_plugin = self.trainer.training_type_plugin self._results.log( diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index e75a5568aae0f..a55982562ff1b 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -20,9 +20,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - def is_lightning_optimizer(optimizer): return isinstance(optimizer, LightningOptimizer) @@ -133,18 +130,10 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n optimizer = self._optimizer model = trainer.get_model() - if trainer._device_type == DeviceType.TPU: - with trainer.profiler.profile(profiler_name): - xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) - - # elif trainer.amp_backend is not None: - # # TODO: Adapt for new optimizer structure - # trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) - - else: - with trainer.profiler.profile(profiler_name): - optimizer.step(closure=closure, *args, **kwargs) - + with trainer.profiler.profile(profiler_name): + trainer.accelerator_backend.optimizer_step(*args, lambda_closure=closure, **kwargs) + + # TODO: Do we need this? 
        accelerator_backend = trainer.accelerator_backend
        if accelerator_backend is not None and accelerator_backend.rpc_enabled:
            if accelerator_backend.ddp_plugin.is_main_rpc_process:
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index 91cebaee2bd4c..d4ac91edaba61 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -11,6 +11,10 @@
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin  # noqa: F401
@@ -19,17 +23,8 @@
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin  # noqa: F401
 
 __all__ = [
-    "ApexMixedPrecisionPlugin",
-    "DataParallelPlugin",
-    "DDP2Plugin",
-    "DDPPlugin",
-    "DDPSpawnPlugin",
-    "HorovodPlugin",
-    "NativeMixedPrecisionPlugin",
-    "PrecisionPlugin",
-    "ShardedNativeMixedPrecisionPlugin",
-    "SingleDevicePlugin",
-    "SingleTPUPlugin",
-    "TPUHalfPrecisionPlugin",
-    "TPUSpawnPlugin",
+    "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", "HorovodPlugin",
+    "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", "SingleDevicePlugin",
+    "SingleTPUPlugin", "TPUHalfPrecisionPlugin", "TPUSpawnPlugin", "RPCPlugin", "RPCSequentialPlugin",
+    "TrainingTypePlugin", "ParallelPlugin", "Plugin", "DDPShardedPlugin", "DDPSpawnShardedPlugin",
 ]
diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py
index 1d1f203afa38a..32d73c46e21c1 100644
--- a/pytorch_lightning/plugins/training_type/__init__.py
+++ b/pytorch_lightning/plugins/training_type/__init__.py
@@ -4,6 +4,8 @@
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 91d44fbdaa5d1..e8e3559246c81 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import List, Optional
@@ -22,7 +23,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
-from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
 
 
 class ParallelPlugin(TrainingTypePlugin, ABC):
@@ -102,3 +103,13 @@ def block_backward_sync(self):
             yield self.model.no_sync()
         else:
             yield None
+
+    def broadcast(self, obj: object, src: int) -> object:
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float)
+        data = all_gather_ddp_if_available(data_tensor)
+        buffer = io.BytesIO(data.cpu().byte().numpy())
+        obj = torch.load(buffer)
+        return obj
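The `broadcast` implementation added above serialises the object with `torch.save`, ships the raw bytes as a tensor, and deserialises them on the receiving side. A runnable sketch of that round trip, leaving out the `all_gather_ddp_if_available` exchange (which needs an initialised process group; the sample object is hypothetical and not part of the patch):

import io

import torch

obj = {"epoch": 3, "state": [1, 2, 3]}  # hypothetical object to broadcast

buffer = io.BytesIO()
torch.save(obj, buffer)
data_tensor = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float)  # payload as a float tensor

# in ParallelPlugin.broadcast, this byte payload is what gets exchanged between
# ranks via all_gather_ddp_if_available before being decoded again

restored = torch.load(io.BytesIO(data_tensor.byte().numpy()))
print(restored == obj)  # True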
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index eeda8ab81bdf3..755d138bc17d6 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -93,6 +93,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
             return dataloader
 
         is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
+
         need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
         if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 20438f427d315..9875b0b038935 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -141,6 +141,7 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                 raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
         return lr_schedulers
 
+
 class _MockOptimizer(Optimizer):
     """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
     is returned from `configure_optimizers`.
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 39dcbc6c7c3e0..bce8192db9d4b 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -15,11 +15,12 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
-from typing import cast, List, Optional, Type, TypeVar, Union, Any
+from typing import Any, cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
-from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -46,6 +47,10 @@
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 
+from pytorch_lightning.loggers.base import LightningLoggerBase
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.utilities.model_utils import is_overridden
+
 
 class TrainerProperties(ABC):
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5cdfa5021acb8..f19dc661fb27d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -31,7 +31,6 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.plugins.legacy.plugin_connector import PluginConnector
 from pytorch_lightning.profiler import BaseProfiler
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
@@ -306,7 +305,7 @@ def __init__(
         self.config_validator = ConfigValidator(self)
         self.data_connector = DataConnector(self)
         self.optimizer_connector = OptimizerConnector(self)
-        self.plugin_connector = PluginConnector(self, plugins)
+
         self.accelerator_connector = BackendConnector(
             num_processes,
             tpu_cores,
@@ -321,7 +320,8 @@ def __init__(
             precision,
             amp_backend,
             amp_level,
-            self.plugin_connector.cloud_environment
+            self.plugin_connector.cloud_environment,
+            plugins
         )
         self.logger_connector = LoggerConnector(self, log_gpu_memory)
         self.model_connector = ModelConnector(self)
@@ -1057,16 +1057,6 @@ def call_hook(self, hook_name, *args, **kwargs):
             self._cache_logged_metrics()
         return output
 
-    @staticmethod
-    def available_plugins():
-        """
-        List of all available plugins that can be string arguments to the trainer.
-
-        Returns:
-            List of all available plugins that are supported as string arguments.
- """ - return PluginConnector.available_plugins() - @property def training(self) -> bool: return self._running_stage == RunningStage.TRAINING diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 695741ed3cd22..f7a86dbfcbc36 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -24,6 +24,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result +from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing