diff --git a/.yapfignore b/.yapfignore index 6def4861b4858..48c75600b1fa2 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1,5 +1 @@ .git/* - - -# TODO -pytorch_lightning/plugins/legacy/* diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 6619eed0209c6..b235e6a458e6c 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -580,9 +580,9 @@ Below are the possible configurations we support. Implement Your Own Distributed (DDP) training ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you need your own way to init PyTorch DDP you can override :meth:`pytorch_lightning.plugins.legacy.ddp_plugin.DDPPlugin.init_ddp_connection`. +If you need your own way to init PyTorch DDP you can override :meth:`pytorch_lightning.plugins.training_type.ddp.DDPPlugin.init_ddp_connection`. -If you also need to use your own DDP implementation, override :meth:`pytorch_lightning.plugins.legacy.ddp_plugin.DDPPlugin.configure_ddp`. +If you also need to use your own DDP implementation, override :meth:`pytorch_lightning.plugins.training_type.ddp.DDPPlugin.configure_ddp`. ---------- @@ -679,7 +679,7 @@ In addition, we use Gradient Checkpointing to reduce GPU memory requirements fur Reference: https://arxiv.org/abs/1811.06965 -.. note:: DDPSequentialPlugin is currently supported only for Pytorch 1.6. +.. note:: RPCSequentialPlugin is currently supported only for Pytorch 1.6. To get started, install FairScale using the command below. We install a specific branch which contains PyTorch related fixes for Sequential Parallelism. @@ -692,7 +692,7 @@ This should be kept within the ``sequential_module`` variable within your ``Ligh .. code-block:: python - from pytorch_lightning.plugins.legacy.ddp_sequential_plugin import DDPSequentialPlugin + from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin from pytorch_lightning import LightningModule class MyModel(LightningModule): @@ -702,7 +702,7 @@ This should be kept within the ``sequential_module`` variable within your ``Ligh # Split my module across 4 gpus, one layer each model = MyModel() - plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1]) + plugin = RPCSequentialPlugin(balance=[1, 1, 1, 1]) trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin]) trainer.fit(model) diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 81de052b3fd0a..9e3461ec54634 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -18,7 +18,7 @@ to balance across your GPUs. To run: -python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential +python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_rpc_sequential """ import math from argparse import ArgumentParser @@ -32,7 +32,7 @@ from pl_examples import cli_lightning_logo from pytorch_lightning import Trainer from pytorch_lightning.metrics.functional import accuracy -from pytorch_lightning.plugins.legacy.ddp_sequential_plugin import DDPSequentialPlugin +from pytorch_lightning.plugins import RPCSequentialPlugin from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE if _BOLTS_AVAILABLE: @@ -201,7 +201,7 @@ def instantiate_datamodule(args): if __name__ == "__main__": cli_lightning_logo() parser = ArgumentParser(description="Pipe Example") - parser.add_argument("--use_ddp_sequential", action="store_true") + parser.add_argument("--use_rpc_sequential", action="store_true") parser = Trainer.add_argparse_args(parser) parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser) args = parser.parse_args() @@ -212,8 +212,8 @@ def instantiate_datamodule(args): cifar10_dm = instantiate_datamodule(args) plugins = None - if args.use_ddp_sequential: - plugins = DDPSequentialPlugin() + if args.use_rpc_sequential: + plugins = RPCSequentialPlugin() model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization) @@ -223,4 +223,4 @@ def instantiate_datamodule(args): if trainer.accelerator_backend.rpc_enabled: # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + trainer.training_type_plugin.exit_rpc_process() diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index 41af4fe84c7f0..c9e054c032804 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.plugins.legacy.plugin import LightningPlugin - -class ClusterEnvironment(LightningPlugin): +class ClusterEnvironment: def __init__(self): self._world_size = None diff --git a/pytorch_lightning/plugins/legacy/__init__.py b/pytorch_lightning/plugins/legacy/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/pytorch_lightning/plugins/legacy/apex.py b/pytorch_lightning/plugins/legacy/apex.py deleted file mode 100644 index 6968296e1ff7f..0000000000000 --- a/pytorch_lightning/plugins/legacy/apex.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Tuple, Union - -import torch -from torch.optim.optimizer import Optimizer - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import rank_zero_warn - -if _APEX_AVAILABLE: - from apex import amp - - -class ApexPlugin(PrecisionPlugin): - - def __init__(self, trainer=None): - self.trainer = trainer - - def connect(self, model, optimizers): - model, optimizers = self.configure_apex(amp, model, optimizers, self.trainer.amp_level) - self.trainer.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers) - return model, optimizers - - def training_step(self, fx, args): - output = fx(args) - return output - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - closure_loss = amp.scale_loss(closure_loss, optimizer) - - # enter apex context - self.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX)) - context = closure_loss - closure_loss = closure_loss.__enter__() - - # do backward pass - if self.trainer.train_loop.automatic_optimization: - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx) - else: - closure_loss.backward(*args, **kwargs) - - # exit amp context - a, b, c = None, None, None - error = context.__exit__(a, b, c) - if error: - rank_zero_warn(a, b, c) - raise Exception('apex unscale error') - - # once backward has been applied, release graph - closure_loss = closure_loss.detach() - return closure_loss - - def configure_apex( - self, - amp: object, - model: LightningModule, - optimizers: List[Optimizer], - amp_level: str, - ) -> Tuple[LightningModule, List[Optimizer]]: - r""" - Override to init AMP your own way. - Must return a model and list of optimizers. - - Args: - amp: pointer to amp library object. - model: pointer to current :class:`LightningModule`. - optimizers: list of optimizers passed in :meth:`configure_optimizers`. - amp_level: AMP mode chosen ('O1', 'O2', etc...) - - Return: - Apex wrapped model and optimizers - - Examples:: - - # Default implementation used by Trainer. - def configure_apex(self, amp, model, optimizers, amp_level): - model, optimizers = amp.initialize( - model, optimizers, opt_level=amp_level, - ) - - return model, optimizers - """ - model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) - return model, optimizers - - def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): - """ - This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights. - This is important when setting amp_level to O2, and the master weights are in fp16. - - Args: - grad_clip_val: Maximum norm of gradients. - optimizer: Optimizer with gradients that will be clipped. - norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - """ - model = self.trainer.get_model() - parameters = model.parameters() - max_norm = float(grad_clip_val) - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] - - if len(parameters) == 0: - return torch.tensor(0.) - device = parameters[0].grad.device - total_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) - clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef.to(p.grad.device)) - - @property - def norm_clipping_epsilon(self): - return 1e-5 - - def optimizer_step(self, trainer, optimizer, closure): - # apex amp does not yet support closures. - # TODO: pass the closure to the step ASAP - with trainer.profiler.profile("closure"): - closure() - - if not self.trainer.train_loop.automatic_optimization: - trainer.call_hook("on_after_backward") - - with trainer.profiler.profile("optimizer_step"): - optimizer.step() diff --git a/pytorch_lightning/plugins/legacy/ddp_plugin.py b/pytorch_lightning/plugins/legacy/ddp_plugin.py deleted file mode 100644 index 4d7303dd7035f..0000000000000 --- a/pytorch_lightning/plugins/legacy/ddp_plugin.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from contextlib import contextmanager -from typing import Any, Dict, List, Union - -import torch.distributed as torch_distrib -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import Optimizer - -from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.distributed import LightningDistributedModule, prepare_for_backward -from pytorch_lightning.plugins.legacy.plugin import LightningPlugin -from pytorch_lightning.utilities import DeviceType - - -class DDPPlugin(LightningPlugin): - """ - Plugin to link a custom ddp implementation to any arbitrary accelerator. - - This plugin forwards all constructor arguments to :class:`~torch.nn.parallel.DistributedDataParallel`. - - Example:: - - class MyDDP(DDPPlugin): - - def configure_ddp(self, model, device_ids): - model = MyDDPWrapper(LightningDistributedModule(model), device_ids) - return model - - my_ddp = MyDDP() - trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp]) - """ - - def __init__(self, **kwargs): - self._ddp_kwargs: Dict[str, Any] = kwargs - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - """ - Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`. - Override to define a custom DDP implementation. - - .. note:: This requires that your DDP implementation subclasses - :class:`~torch.nn.parallel.DistributedDataParallel` and that - the original LightningModule gets wrapped by - :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`. - - The default implementation is:: - - def configure_ddp(self, model, device_ids): - model = DistributedDataParallel( - LightningDistributedModule(model), - device_ids=device_ids, - **self._ddp_kwargs, - ) - return model - - Args: - model: the LightningModule - device_ids: the list of devices available - - Returns: - the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel` - - """ - # if unset, default `find_unused_parameters` `True` - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( - "find_unused_parameters", True - ) - model = DistributedDataParallel( - module=LightningDistributedModule(model), - device_ids=device_ids, - **self._ddp_kwargs, - ) - return model - - def init_ddp_connection( - self, - trainer, - cluster_environment, - global_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True, - ) -> None: - # Todo: required argument `is_slurm_managing_tasks` is not used - os.environ["MASTER_ADDR"] = str(cluster_environment.master_address()) - os.environ["MASTER_PORT"] = str(cluster_environment.master_port()) - os.environ["WORLD_SIZE"] = str(cluster_environment.world_size()) - torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo" - - if not torch_distrib.is_initialized(): - log.info( - f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" - ) - torch_distrib.init_process_group( - torch_backend, rank=global_rank, world_size=world_size - ) - - @property - def is_running_single_process_per_device(self) -> bool: - # objects do not need to be scattered in single process per device, move objects upfront to device - # This property is used in ``self.on_before_forward`` function. - return self.device_ids is not None and len(self.device_ids) == 1 - - def on_before_forward(self, model: LightningModule, *args): - """ - Override to handle custom edge case. - - Args: - args: Inputs to the model. - model: Model to train. - - Returns: - args moved to correct device if needed. - """ - if self.is_running_single_process_per_device: - args = model.transfer_batch_to_device(args, model.device) - return args - - def optimizer_state(self, optimizer: Optimizer) -> dict: - return optimizer.state_dict() - - def on_after_setup_optimizers(self, trainer): - """ - Called after optimizers have been set-up. This is useful for doing any configuration options in RPC, or - state sharding. - """ - - def get_model_from_plugin( - self, - model: Union[DistributedDataParallel, LightningModule] - ) -> LightningModule: - """ - Override to modify returning base :class:`LightningModule` - when accessing variable and functions outside of the parallel wrapper. - - Example:: - ref_model = ddp_plugin.get_model_from_plugin(model) - ref_model.training_step(...) - - Args: - model: Model with parallel wrapper. - - Returns: - Reference :class:`LightningModule` within parallel wrapper. - - """ - if isinstance(model, DistributedDataParallel): - model = model.module - if isinstance(model, LightningDistributedModule): - model = model.module - return model - - @contextmanager - def block_backward_sync(self, model: DistributedDataParallel): - """ - Blocks ddp sync gradients behaviour on backwards pass. - This is useful for skipping sync when accumulating gradients, reducing communication overhead - - Returns: - context manager with sync behaviour off - """ - yield model.no_sync() - - def on_before_manual_backward(self, model: DistributedDataParallel, output: Any): - prepare_for_backward(model, output) - - def distributed_sampler_kwargs(self, distributed_sampler_kwargs): - return distributed_sampler_kwargs - - @property - def data_parallel_group(self): - """ - Return the group that this process exists in. By default, this is the world size. - Useful for when additional parallel groups have been created, to select certain processes. - - Returns: - The ProcessGroup this process exists in. - """ - return torch_distrib.group.WORLD diff --git a/pytorch_lightning/plugins/legacy/ddp_sequential_plugin.py b/pytorch_lightning/plugins/legacy/ddp_sequential_plugin.py deleted file mode 100644 index 1a6a18c206f81..0000000000000 --- a/pytorch_lightning/plugins/legacy/ddp_sequential_plugin.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -import os -from typing import Any, List, Optional - -import torch -import torch.distributed as torch_distrib -from torch import nn -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning import LightningModule -from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _FAIRSCALE_PIPE_AVAILABLE: - import fairscale.nn.model_parallel as mpu - from fairscale.nn import PipeRPCWrapper - from fairscale.nn.pipe import balance as pipe_balance - from fairscale.nn.pipe import rpc as rpc_pipe - from fairscale.nn.pipe.pipeline import PipelineStyle - - -class DDPSequentialPlugin(RPCPlugin): - def __init__( - self, - balance: Optional[List[int]] = None, - microbatches: int = 8, - checkpoint: str = 'except_last', - balance_mode: str = "balance_by_size", - pipelined_backward: Optional[bool] = True, - **kwargs): - """ - Provides sequential model parallelism for :class:`nn.Sequential ` module. - If the module requires lots of memory, Pipe can be used to reduce this by leveraging multiple GPUs. - - Example:: - class MyLightningModule: - def __init__(self): - ... - model.sequential_module = torch.nn.Sequential(my_layers) - - # Split my module across 4 gpus, one layer each - model = MyLightningModule() - plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1]) - trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin]) - trainer.fit(model) - - .. _DDPSequentialPlugin: https://arxiv.org/abs/1811.06965 - - Pipeline parallelism comes with with checkpointing to reduce peak - memory required to train while minimizing device under-utilization. - This is turned on by default and can be turned off via the checkpoint argument. - - You should determine the balance when defining the plugin, - or you can pass an example input array via the LightningModule to infer a balance. - The module will be partitioned into multiple devices according to the given balance. You may also rely on - your own heuristics to find your own optimal configuration. - - Args: - balance: The balance of the model, i.e [2, 2] (two layers on each GPU). - If not provided assumes user provides an input example array to find a balance on all GPUs. - - microbatches: Allows for parallelization to reduce device utilization - by splitting the batch into further smaller batches. - - checkpoint: Enables gradient checkpointing. ['always', 'except_last', 'never'] - - balance_mode: Type of balance heuristic to use if balance to be inferred. - - - 'balance_by_size': checks memory usage of each layer and determines balance - - - 'balance_by_time': checks time of each layer and determines balance - - pipelined_backward: if True, call torch.autograd.backward once per microbatch on the - - backward pass (instead of once for the whole batch). This works - around a potential deadlock in pytorch when using tensor parallelism - at the same time. Defaults to `True` if - `get_model_parallel_world_size() > 1` - """ - self._check_pipe_available() - super().__init__(**kwargs) - - self.balance = balance - - self.microbatches = microbatches - self.checkpoint = checkpoint - self.balance_mode = balance_mode - self.pipelined_backward = pipelined_backward - self.main_rpc_process = False # Updated by main process, default for all secondary processes - - def init_ddp_connection( - self, - trainer, - cluster_environment, - global_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True, - ) -> None: - trainer.prepared_for_backwards = False - self._check_arguments(trainer) - if self._skip_init_connections(trainer): - return - super().init_ddp_connection( - trainer=trainer, - cluster_environment=cluster_environment, - global_rank=global_rank, - world_size=world_size, - is_slurm_managing_tasks=is_slurm_managing_tasks - ) - super().init_rpc_connection( - global_rank=global_rank, - world_size=world_size - ) - model = trainer.get_model() - self.gpus_per_model = self._infer_check_num_gpus(trainer) - self.init_model_parallel_groups(trainer) - self.set_main_rpc_process() - - self._check_sequential_model_exists(model) - if self.main_rpc_process: - if self.balance is None: - self._infer_model_balance(trainer) - self._assert_valid_model_balance(trainer) - - def on_before_manual_backward(self, model: DistributedDataParallel, output: Any): - pass - - def _infer_model_balance(self, trainer): - log.info(f'Inferring model balance using {self.balance_mode} mode') - model = trainer.get_model() - if model.example_input_array is None: - raise MisconfigurationException( - 'Please set example_input_array to your model, so we can infer the right model balance for you') - balance_func = getattr(pipe_balance, self.balance_mode) - self.balance = balance_func(self.gpus_per_model, model.sequential_module, model.example_input_array) - self._sync_balance_to_all_parallel_groups() - - log.info(f'The following model balance {self.balance.tolist()} was inferred using {self.balance_mode} mode') - - def _sync_balance_to_all_parallel_groups(self, main_rank=0): - """ - Ensures that we sync the balance to all main processes, so that the balance is the same per replica. - - Args: - main_rank: The rank with the balance we'd like to replicate. - """ - self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda') - # Ensure we sync to all processes within the main data parallel group - # We use the data parallel group as all main processes are found within the same group - torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group()) - self.balance = self.balance.cpu() - - def _check_sequential_model_exists(self, model): - if not hasattr(model, "sequential_module") or not isinstance(model.sequential_module, nn.Sequential): - raise MisconfigurationException( - 'Could not find a PipeLightningModule within the model. ' - 'Did you set your sequential model as the `sequential_module` attribute of your model?') - - def _find_and_init_pipe_module(self, model): - if hasattr(model, "sequential_module") and isinstance(model.sequential_module, LightningPipeModule): - # model has been wrapped already - return - elif hasattr(model, "sequential_module") and isinstance(model.sequential_module, nn.Sequential): - # try to wrap model for the user - model.sequential_module = LightningPipeModule( - model.sequential_module, - balance=self.balance, - microbatches=self.microbatches, - checkpoint=self.checkpoint, - ) - # Update references for workers to access correct lightning functions when calling RPC - model.sequential_module.trainer = model.trainer - model.sequential_module.configure_optimizers = model.configure_optimizers - - # Update references for main process to access correct lightning functions when calling RPC - model.sequential_module.module.model.trainer = model.trainer - model.sequential_module.module.model.configure_optimizers = model.configure_optimizers - - else: - raise MisconfigurationException( - 'Could not find a PipeLightningModule within the model. ' - 'Did you defined set your sequential model as an `sequential_module` attribute of your model ?' - ) - - def _assert_valid_model_balance(self, trainer): - model = trainer.get_model() - if sum(self.balance) != len(model.sequential_module): - raise MisconfigurationException( - f'The provided balance sum: {sum(self.balance)} does not' - f' match your Sequential length: {len(model.sequential_module)}') - - def _skip_init_connections(self, trainer): - """ - Skip initialization if torch is already initialized and we're in testing. - - Returns: - Whether to skip initialization - - """ - return torch_distrib.is_initialized() and trainer.testing - - def init_model_parallel_groups(self, trainer): - num_model_parallel = 1 # TODO currently no support for vertical model parallel - mpu.initialize_model_parallel( - model_parallel_size_=num_model_parallel, - pipeline_length=self.gpus_per_model - ) - - def _infer_check_num_gpus(self, trainer): - """ - Infer the number of GPUs per model. - - Args: - trainer: The trainer object. - - Returns: - The appropriate balance for the model - """ - if isinstance(self.balance, list): - if len(self.balance) != (trainer.world_size / trainer.num_nodes): - raise MisconfigurationException( - "Pipe currently only supports splitting the module onto all available GPUs" - ) - # User has defined a balance for his model - return len(self.balance) - # Assume that the user wants to balance his model on all GPUs - return trainer.world_size - - def on_accelerator_exit_rpc_process(self, trainer) -> None: - if not trainer.testing: - torch_distrib.barrier() # Ensure we await main process initialization - - # Add trainer/configure_optimizers to the pipe model for access in all worker processes - rpc_pipe.PipeModel.trainer = trainer - del rpc_pipe.PipeModel.trainer.model.sequential_module - rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel - rpc_pipe.PipeModel.configure_optimizers = trainer.model.configure_optimizers - super().on_accelerator_exit_rpc_process(trainer) - - def set_main_rpc_process(self): - self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0 - - def on_main_rpc_connection(self, trainer) -> None: - # Create pipe_module - model = trainer.get_model() - self._find_and_init_pipe_module(model) - if not trainer.testing: - torch_distrib.barrier() # Ensure we join main process initialization - model.sequential_module.foreach_worker(register_optimizers, include_self=True) - - def _check_arguments(self, trainer): - if trainer.amp_backend is not None: - raise MisconfigurationException( - 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision') - - def configure_ddp( - self, - model: LightningModule, device_ids: List[int]) -> DistributedDataParallel: - model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids) - # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel - model.require_backward_grad_sync = False - return model - - @rank_zero_only - def rpc_save_model( - self, - save_model_fn, - last_filepath, - trainer, - pl_module) -> None: - model = trainer.get_model() - if not hasattr(model.sequential_module, "foreach_worker"): - return - current_layers = pl_module.sequential_module - model.sequential_module.foreach_worker( - save_layers_on_all_rank_zero_workers, - {"gpus_per_model": self.gpus_per_model}, - include_self=True - ) - pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model) - save_model_fn(last_filepath, trainer, pl_module) - pl_module.sequential_module = current_layers - - def worker_optimizer_step( - self, - model: LightningModule, - opt_idx: int, - *args, - **kwargs) -> None: - model.sequential_module.foreach_worker( - run_optimizer, - {"opt_idx": opt_idx, "args": args, "kwargs": kwargs}, - include_self=False - ) - - def distributed_sampler_kwargs(self, distributed_sampler_kwargs): - return dict( - num_replicas=mpu.get_data_parallel_world_size(), - rank=mpu.get_data_parallel_rank(), - ) - - @property - def data_parallel_group(self): - return mpu.get_data_parallel_group() - - @property - def is_main_rpc_process(self) -> bool: - return self.main_rpc_process - - @property - def return_after_exit_rpc_process(self) -> bool: - return True - - def barrier(self, name: Optional[str] = None) -> None: - if torch_distrib.is_initialized() and self.is_main_rpc_process: - torch_distrib.barrier(group=self.data_parallel_group) - - def _check_pipe_available(self): - if not _FAIRSCALE_PIPE_AVAILABLE: - raise MisconfigurationException( - 'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.' - ) - - -class LightningPipeModule(nn.Module): - """ - This class wraps Fairscale Pipe and PipeRCPWrapper class. - """ - - def __init__( - self, - module: nn.Sequential, - balance: List[int], - microbatches: int = 8, - checkpoint='never'): - super().__init__() - self.module = module - self.balance = balance - self.microbatches = microbatches - self.checkpoint = checkpoint - self._init_pipe() - - def _init_pipe(self): - device = torch.device("cuda", torch_distrib.get_rank()) - - self.module = PipeRPCWrapper( - module=self.module, - balance=self.balance, - chunks=self.microbatches, - style=PipelineStyle.MultiProcess, - input_device=device, - worker_map=self.get_worker_map(), - checkpoint=self.checkpoint, - ) - - def foreach_worker(self, *args, **kwargs): - self.module.foreach_worker(*args, **kwargs) - - def forward(self, *args, **kwargs): - return self.module(*args, **kwargs) - - def get_worker_map(self): - # TODO, is this correct with multinodes? We also assume "worker" is the same as defined in the RPCPlugin - return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())} - - -def register_optimizers(ctx, model): - optimizers, lr_schedulers, optimizer_frequencies = model.trainer.init_optimizers(model) - model.trainer.optimizers = optimizers - model.trainer.lr_schedulers = lr_schedulers - model.trainer.optimizer_frequencies = optimizer_frequencies - - -def run_optimizer(ctx, model): - trainer = model.trainer - opt_idx = ctx["opt_idx"] - optimizer = trainer.optimizers[opt_idx] - optimizer.step(*ctx["args"], **ctx["kwargs"]) - - -def save_layers_on_all_rank_zero_workers(ctx, model): - gpus_per_model = ctx["gpus_per_model"] - rank = torch_distrib.get_rank() - if rank in range(gpus_per_model): - seq = list(model.children())[0] - torch.save(seq, f"seq_{rank}.pt") - - -def load_sequential_from_saved_layers(gpus_per_model): - partial_seqs = [torch.load(f"seq_{rank}.pt", map_location='cpu') for rank in range(gpus_per_model)] - seq = nn.Sequential() - for p_seq in partial_seqs: - for name, child in p_seq.named_children(): - seq.add_module(name, child) - # delete tmp files - [os.remove(f"seq_{rank}.pt") for rank in range(gpus_per_model)] - return seq diff --git a/pytorch_lightning/plugins/legacy/native_amp.py b/pytorch_lightning/plugins/legacy/native_amp.py deleted file mode 100644 index 0a38a90acb79f..0000000000000 --- a/pytorch_lightning/plugins/legacy/native_amp.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Union - -import torch -from torch.optim import Optimizer - -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin - - -class NativeAMPPlugin(PrecisionPlugin): - - def __init__(self, trainer=None): - """ - Integrates native amp into Lightning's internals. - """ - self.trainer = trainer - - def connect(self, model, optimizers): - return model, optimizers - - def training_step(self, fx, args): - with torch.cuda.amp.autocast(): - output = fx(*args) - return output - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - closure_loss = self.trainer.scaler.scale(closure_loss) - - automatic_optimization = self.trainer.train_loop.automatic_optimization - - # do backward pass - if automatic_optimization: - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx) - else: - closure_loss.backward(*args, **kwargs) - - # once backward has been applied, release graph - closure_loss = closure_loss.detach() - - # unscale gradient to allow analyze within `on_after_backward` - if not self.trainer.train_loop.should_accumulate() and automatic_optimization: - if isinstance(optimizer, LightningOptimizer): - self.trainer.scaler.unscale_(optimizer.optimizer) - else: - self.trainer.scaler.unscale_(optimizer) - - return closure_loss - - def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): - model = self.trainer.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) - - @property - def scaler(self): - return torch.cuda.amp.GradScaler() - - def optimizer_step(self, trainer, optimizer, closure): - # native amp does not yet support closures. - # TODO: pass the closure to the step ASAP - with trainer.profiler.profile("closure"): - closure() - - if not self.trainer.train_loop.automatic_optimization: - trainer.scaler.unscale_(optimizer) - trainer.call_hook("on_after_backward") - - with trainer.profiler.profile("optimizer_step"): - trainer.scaler.step(optimizer) - trainer.scaler.update() diff --git a/pytorch_lightning/plugins/legacy/plugin.py b/pytorch_lightning/plugins/legacy/plugin.py deleted file mode 100644 index c02cfa5a19848..0000000000000 --- a/pytorch_lightning/plugins/legacy/plugin.py +++ /dev/null @@ -1,30 +0,0 @@ -from pytorch_lightning.utilities import AMPType - - -class LightningPlugin: - """ - Defines base class for Plugins. Plugins represent functionality that can be injected into the lightning codebase. - """ - - def required_plugins(self, amp_backend: AMPType, trainer) -> list: - """ - Override to define additional required plugins. This is useful for when custom plugins - need to enforce override of other plugins. - - Returns: - Optional list of plugins containing additional plugins. - - Example:: - - class MyPlugin(DDPPlugin): - def required_plugins(self): - return [MyCustomAMPPlugin()] - - # Will automatically add the necessary AMP plugin - trainer = Trainer(plugins=[MyPlugin()]) - - # Crash as MyPlugin enforces custom AMP plugin - trainer = Trainer(plugins=[MyPlugin(), NativeAMPPlugin()]) - - """ - return [] diff --git a/pytorch_lightning/plugins/legacy/plugin_connector.py b/pytorch_lightning/plugins/legacy/plugin_connector.py deleted file mode 100644 index 22f97bf8b77f3..0000000000000 --- a/pytorch_lightning/plugins/legacy/plugin_connector.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from enum import Enum -from typing import List, Optional, Union - -from pytorch_lightning.plugins.environments import ClusterEnvironment -from pytorch_lightning.plugins.legacy.apex import ApexPlugin -from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin -from pytorch_lightning.plugins.legacy.plugin import LightningPlugin -from pytorch_lightning.plugins.legacy.sharded_plugin import DDPShardedPlugin -from pytorch_lightning.utilities import AMPType, rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class PluginConnector: - - def __init__(self, trainer): - self.trainer = trainer - self.plugins = [] - self.ddp_plugin = DDPPlugin() - self.cloud_environment = None - - def on_trainer_init(self, plugins: Optional[Union[str, list]]): - self.plugins = plugins - if self.plugins is None: - self.plugins = [] - self.plugins = self._convert_str_custom_plugins(self.plugins) - self.plugins = self._append_required_plugins(self.plugins) - self.__attach_ddp() - self.__attach_cluster() - self.__attach_amp() - self.__attach_apex() - - def __attach_amp(self): - amp_plugin = self.__attach_plugin(NativeAMPPlugin) - if amp_plugin: - self.trainer.amp_backend = AMPType.NATIVE - self.trainer.precision_connector.backend = amp_plugin - - def __attach_apex(self): - apex_plugin = self.__attach_plugin(ApexPlugin) - if apex_plugin: - self.trainer.amp_backend = AMPType.APEX - self.trainer.precision_connector.backend = apex_plugin - - def __attach_plugin(self, plugin_type, limit=1): - count = 0 - plugin_result = None - for plugin in self.plugins: - if isinstance(plugin, plugin_type): - - # count the clusters - count += 1 - if count > limit: - m = f'you can only use one {plugin_type.__class__} in plugins. You passed in: {count}' - raise MisconfigurationException(m) - - plugin_result = plugin - - return plugin_result - - def __attach_ddp(self, limit=1): - count = 0 - for plugin in self.plugins: - if isinstance(plugin, DDPPlugin): - - # count the clusters - count += 1 - if count > limit: - m = f'you can only use one DDP plugin in plugins. You passed in: {count}' - raise MisconfigurationException(m) - - # set the ddp plugin - self.ddp_plugin = plugin - - def __attach_cluster(self, limit=1): - num_clusters = 0 - for plugin in self.plugins: - if isinstance(plugin, ClusterEnvironment): - - # count the clusters - num_clusters += 1 - if num_clusters > limit: - m = f'you can only use one cluster environment in plugins. You passed in: {num_clusters}' - raise MisconfigurationException(m) - - # set the cluster - self.cloud_environment = plugin - - def _convert_str_custom_plugins(self, plugins: Union[str, list]): - """ - Converts string inputs to corresponding supported lightning plugins. - - Args: - plugins: List of plugins or string to choose lightning plugin. - - Returns: - List of plugins where strings are now plugins. - """ - if isinstance(plugins, str): - return [self._convert_str_to_plugin(plugins)] - return [self._convert_str_to_plugin(plugin) for plugin in plugins] - - def _convert_str_to_plugin(self, plugin): - if isinstance(plugin, str): - if plugin not in LightningCustomPlugins.__members__: - raise MisconfigurationException( - f" {plugin} is not a supported lightning custom plugin." - " If you're trying to pass a custom plugin, please pass this as an object to" - " Trainer(plugins=[MyPlugin()]." - f" Supported plugins as string input: {[e.name for e in LightningCustomPlugins]}." - ) - plugin_cls = LightningCustomPlugins[plugin].value - return plugin_cls(trainer=self.trainer) - return plugin - - def _append_required_plugins(self, plugins: List[LightningPlugin]): - """ - Allows custom plugins to define additional plugins. This is useful for when custom plugins - need to enforce override of native amp/apex when they are enabled. - - Args: - plugins: List of plugins - - Returns: - List of plugins containing additional plugins if needed. - - Example:: - - class MyPlugin(DDPPlugin): - def required_plugins(self): - return [MyCustomAMPPlugin()] - - # Will automatically add the necessary AMP plugin - trainer = Trainer(plugins=[MyPlugin()]) - - # Crash as MyPlugin enforces custom AMP plugin - trainer = Trainer(plugins=[MyPlugin(), NativeAMPPlugin()]) - - """ - for plugin in plugins: - required_plugins = plugin.required_plugins(amp_backend=self.trainer.amp_backend, trainer=self.trainer) - if required_plugins: - rank_zero_warn( - f'plugin {type(plugin)} has added additional required plugins as default:' - f' {[type(x) for x in required_plugins]}' - 'Extend this plugin and override `required_plugins`' - 'if this conflicts with your additional plugins.' - ) - plugins += required_plugins - return plugins - - @classmethod - def available_plugins(cls): - """ - List of all available plugins that can be string arguments to the trainer. - - Returns: - List of all available plugins that are supported as string arguments. - """ - return [e.name for e in LightningCustomPlugins] - - -class LightningCustomPlugins(Enum): - """ - String support for custom lightning plugins. - Allows easier access to custom lightning plugins from the command line. - """ - ddp_sharded = DDPShardedPlugin - native_amp = NativeAMPPlugin - apex_amp = ApexPlugin diff --git a/pytorch_lightning/plugins/legacy/precision_plugin.py b/pytorch_lightning/plugins/legacy/precision_plugin.py deleted file mode 100644 index 1041e9d6b0faf..0000000000000 --- a/pytorch_lightning/plugins/legacy/precision_plugin.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Union - -from torch.optim import Optimizer - -from pytorch_lightning.plugins.legacy.plugin import LightningPlugin - - -class PrecisionPlugin(LightningPlugin): - """ - Abstract class to extend for precision support (32/16 etc). - - This is extended to cover any specific logic required for precision support such as AMP/APEX or sharded - training. - """ - - def connect(self, model, optimizers): - raise NotImplementedError - - def training_step(self, fx, args): - raise NotImplementedError - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - raise NotImplementedError - - def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): - raise NotImplementedError diff --git a/pytorch_lightning/plugins/legacy/rpc_plugin.py b/pytorch_lightning/plugins/legacy/rpc_plugin.py deleted file mode 100644 index 7d9a6946c74ff..0000000000000 --- a/pytorch_lightning/plugins/legacy/rpc_plugin.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from contextlib import suppress -from typing import Optional - -import torch - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import _RPC_AVAILABLE - -DEFAULT_RPC_TIMEOUT_SEC = 60. -if _RPC_AVAILABLE: - from torch.distributed import rpc - with suppress(ModuleNotFoundError, ImportError): - from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC - - -class RPCPlugin(DDPPlugin): - """ - Backbone for RPC Plugins built on top of DDP. - RPC introduces different communication behaviour than DDP. Unlike DDP, processes potentially are not - required to run the same code as the main process. - This leads to edge cases where logic needs to be re-defined. This class contains special cases - that need to be addressed when using RPC communication when building custom RPC Plugins. - """ - - def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs): - self.rpc_timeout_sec = rpc_timeout_sec - self._is_rpc_initialized = False - super().__init__(**kwargs) - - def init_rpc_connection(self, - global_rank: int, - world_size: int) -> None: - os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') - rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) - rpc._set_rpc_timeout(self.rpc_timeout_sec) - self._is_rpc_initialized = True - - def rpc_save_model(self, - save_model_fn, - last_filepath, - trainer, - pl_module) -> None: - """ - Override to save model to disk. - This is required as the main process will be required to handle aggregating model states from RPC processes. - - Args: - save_model_fn: The saving function to save final model. - last_filepath: The filepath to save the model to. - trainer: The trainer object. - pl_module: The LightningModule. - """ - raise NotImplementedError - - def on_main_rpc_connection(self, trainer) -> None: - """ - Called when main rpc connection has been established. - - Args: - trainer: The trainer object. - """ - raise NotImplementedError - - def on_accelerator_exit_rpc_process(self, trainer) -> None: - """ - Called to exit RPC process within the accelerator, that is being managed by main process. - - Args: - trainer: The trainer object. - """ - self.exit_rpc_process() - - def exit_rpc_process(self): - if self._is_rpc_initialized: - torch.distributed.rpc.shutdown() - self._is_rpc_initialized = False - - @property - def return_after_exit_rpc_process(self) -> bool: - """ - Override to decide whether to skip train/test function after shutdown completed. - Usually RPC shutdown is a join/exit function, afterwards we want to exit the process. - - Returns: - Whether to return after rpc exit. - """ - raise NotImplementedError - - def worker_optimizer_step(self, - model: LightningModule, - opt_idx: int, - *args, - **kwargs) -> None: - """ - Called when optimizer step is run on the main process. Used to signal any RPC workers to run optimizer step. - - Args: - model: The LightningModule. - opt_idx: The idx of the optimizer to carry out step on. - """ - raise NotImplementedError - - @property - def is_main_rpc_process(self) -> bool: - """ - Override to add logic to determine current process is main RPC process. - """ - raise NotImplementedError - - def barrier(self, name: Optional[str] = None) -> None: - """ - Override to define distributed sync communication. This needs to be handled differently due to - the RPC connection managing certain processes at the same time. - """ - raise NotImplementedError diff --git a/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py deleted file mode 100644 index b2523ef3fce0a..0000000000000 --- a/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import cast, Union - -from torch.optim import Optimizer - -from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE - -if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE: - from fairscale.optim import OSS - from fairscale.optim.grad_scaler import ShardedGradScaler - - -class ShardedNativeAMPPlugin(NativeAMPPlugin): - @property - def scaler(self): - return ShardedGradScaler() - - def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float): - max_norm = grad_clip_val - norm_type = float(2.0) - optimizer = cast(OSS, optimizer) - optimizer.clip_grad_norm(max_norm, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/legacy/sharded_plugin.py b/pytorch_lightning/plugins/legacy/sharded_plugin.py deleted file mode 100644 index a30f0c891514c..0000000000000 --- a/pytorch_lightning/plugins/legacy/sharded_plugin.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Optional, Union - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.optimizer import is_lightning_optimizer -from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.legacy.sharded_native_amp_plugin import ShardedNativeAMPPlugin -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, AMPType, rank_zero_only -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _FAIRSCALE_AVAILABLE: - from fairscale.optim import OSS - - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel - - -class DDPShardedPlugin(DDPPlugin): - - def __init__(self, **kwargs): - self._check_fairscale() - super().__init__(**kwargs) - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ): - self._wrap_optimizers(model) - return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) - - def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: - optimizer.consolidate_state_dict() - return self._optim_state_dict(optimizer) - - def _check_fairscale(self): - if not _FAIRSCALE_AVAILABLE: - raise MisconfigurationException( - 'Sharded DDP Plugin requires Fairscale to be installed.' - ) - - @rank_zero_only - def _optim_state_dict(self, optimizer): - return optimizer.state_dict() - - def _wrap_optimizers(self, model): - trainer = model.trainer - if trainer.testing: - return - - self._reinit_with_fairscale_oss(trainer) - - def _reinit_with_fairscale_oss(self, trainer): - optimizers = trainer.optimizers - for x, optimizer in enumerate(optimizers): - if is_lightning_optimizer(optimizer): - optimizer = optimizer.optimizer - if not isinstance(optimizer, OSS): - optim_class = type(optimizer) - zero_optimizer = OSS( - params=optimizer.param_groups, - optim=optim_class, - **optimizer.defaults - ) - optimizers[x] = zero_optimizer - del optimizer - - def get_model_from_plugin( - self, - model: Union['LightningShardedDataParallel', LightningModule] - ) -> LightningModule: - if isinstance(model, LightningShardedDataParallel): - return model.module - return model - - def required_plugins(self, amp_backend: AMPType, trainer) -> list: - if amp_backend == AMPType.APEX: - raise MisconfigurationException( - 'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision.' - ) - if amp_backend == AMPType.NATIVE: - return [ShardedNativeAMPPlugin(trainer=trainer)] - return [] - - def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any): - pass diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 331cbe76639f3..345f208b97cde 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -42,7 +42,7 @@ class RPCSequentialPlugin(RPCPlugin): def __init__( self, - balance: List[int], + balance: Optional[List[int]] = None, microbatches: int = 8, checkpoint: str = 'except_last', balance_mode: str = "balance_by_size", diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py deleted file mode 100644 index fdb469effa5e9..0000000000000 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytorch_lightning import _logger as log -from pytorch_lightning.plugins.legacy.apex import ApexPlugin -from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn - - -class PrecisionConnector: - - def __init__(self, trainer): - self.trainer = trainer - self.backend = None - - def on_trainer_init(self, precision: int, amp_level: str, amp_backend: str): - # AMP init - # These are the only lines needed after v0.8.0 - # we wrap the user's forward with autocast and give it back at the end of fit - self.trainer.autocast_original_forward = None - self.trainer.precision = precision - self.trainer.scaler = None - - self.trainer.amp_level = amp_level - self.init_amp(amp_backend) - - def init_amp(self, amp_type: str): - assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported' - self.trainer.amp_backend = None - self._setup_amp_backend(amp_type) - - def _setup_amp_backend(self, amp_type: str): - if self.trainer.precision != 16: - # no AMP requested, so we can leave now - return - - amp_type = amp_type.lower() - assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' - if amp_type == 'native': - if not _NATIVE_AMP_AVAILABLE: - rank_zero_warn( - 'You have asked for native AMP but your PyTorch version does not support it.' - ' Consider upgrading with `pip install torch>=1.6`.' - ' We will attempt to use NVIDIA Apex for this session.' - ) - amp_type = 'apex' - else: - self.trainer.amp_backend = AMPType.NATIVE - log.info('Using native 16bit precision.') - self.backend = NativeAMPPlugin(self.trainer) - - if amp_type == 'apex': - if not _APEX_AVAILABLE: - rank_zero_warn( - 'You have asked for Apex AMP but you have not installed it yet.' - ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux' - ) - else: - log.info('Using APEX 16bit precision.') - self.trainer.amp_backend = AMPType.APEX - self.backend = ApexPlugin(self.trainer) - log.warn("LightningOptimizer doesn't support Apex") - - if not self.trainer.amp_backend: - raise ModuleNotFoundError( - f'You have asked for AMP support {amp_type}, but there is no support on your side yet.' - f' Consider installing torch >= 1.6 or NVIDIA Apex.' - ) - - def connect(self, model): - if self.backend: - model, optimizers = self.backend.connect(model, self.trainer.optimizers) - self.trainer.optimizers = optimizers - - return model diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4f9c5d4f5e19f..72236b1589b1d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -54,7 +54,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -409,10 +409,6 @@ def setup_trainer(self, model: LightningModule): model: The model to run sanity test on. """ - # init amp. Must be done here instead of __init__ to allow ddp to work - if self.amp_backend == AMPType.NATIVE and self.precision == 16 and self._device_type != DeviceType.TPU: - self.scaler = self.precision_connector.backend.scaler - # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) diff --git a/setup.cfg b/setup.cfg index 8b02b462fad84..f622581b5aaf7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -169,10 +169,6 @@ ignore_errors = True [mypy-pytorch_lightning.pt_overrides.*] ignore_errors = True -# todo: add proper typing to this module... -[mypy-pytorch_lightning.plugins.legacy.*] -ignore_errors = True - # todo: add proper typing to this module... [mypy-pytorch_lightning.root_module.*] ignore_errors = True