From 533c4168a8c3d42de0e5c866b3eeb806174c6ec5 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 18 Nov 2021 17:25:28 -0800 Subject: [PATCH 01/20] 2/n Move Precision Plugin into strategy - move optimizer related logics --- pytorch_lightning/accelerators/accelerator.py | 172 +----------------- pytorch_lightning/accelerators/cpu.py | 6 +- pytorch_lightning/accelerators/gpu.py | 12 +- pytorch_lightning/accelerators/ipu.py | 13 -- pytorch_lightning/accelerators/tpu.py | 11 +- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/lite/lite.py | 2 +- pytorch_lightning/lite/wrappers.py | 6 +- .../loops/epoch/evaluation_epoch_loop.py | 2 +- .../loops/epoch/prediction_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 2 +- .../loops/optimization/optimizer_loop.py | 4 +- .../plugins/precision/apex_amp.py | 11 +- .../plugins/precision/native_amp.py | 12 +- .../plugins/precision/precision_plugin.py | 10 + .../plugins/training_type/ddp_spawn.py | 3 +- .../plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 18 +- .../plugins/training_type/fully_sharded.py | 2 +- .../plugins/training_type/horovod.py | 8 +- .../plugins/training_type/ipu.py | 13 +- .../plugins/training_type/single_device.py | 4 +- .../plugins/training_type/single_tpu.py | 16 +- .../plugins/training_type/tpu_spawn.py | 17 +- .../training_type/training_type_plugin.py | 110 ++++++++++- .../connectors/accelerator_connector.py | 2 +- .../connectors/checkpoint_connector.py | 2 +- .../trainer/connectors/data_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 18 +- tests/callbacks/test_stochastic_weight_avg.py | 6 +- tests/core/test_datamodules.py | 2 +- tests/lite/test_wrappers.py | 9 +- tests/models/test_gpu.py | 22 +-- tests/models/test_hooks.py | 2 +- .../optimization/test_manual_optimization.py | 16 +- 36 files changed, 274 insertions(+), 269 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index eb3886b209503..fea7cd2e4d367 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,23 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib from abc import abstractmethod -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Optional, Union import torch -from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn import Module -from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin -from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.enums import AMPType, LightningEnum +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -62,10 +54,6 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl if precision_plugin is not None: self.training_type_plugin._precision_plugin = precision_plugin - self.optimizers: List = [] - self.lr_schedulers: List = [] - self.optimizer_frequencies: List = [] - def setup_environment(self) -> None: """Setup any processes or distributed connections. @@ -80,28 +68,18 @@ def setup(self, trainer: "pl.Trainer") -> None: Args: trainer: the trainer instance """ - self.setup_training_type_plugin() - if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) - self.setup_precision_plugin() + self.training_type_plugin.setup(trainer) def pre_dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" - self._move_optimizer_state() + self.training_type_plugin._move_optimizer_state() self.training_type_plugin.pre_dispatch() if self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) + self.training_type_plugin.setup_optimizers(trainer) self.training_type_plugin.precision_plugin.pre_dispatch() - def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: - """Moves the state of the optimizers to the GPU if needed.""" - device = device or self.root_device - for opt in self.optimizers: - for p, v in opt.state.items(): - opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device) - def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.dispatch(trainer) @@ -133,11 +111,6 @@ def lightning_module(self) -> "pl.LightningModule": """ return self.training_type_plugin.lightning_module - @property - def root_device(self) -> torch.device: - """Returns the root device.""" - return self.training_type_plugin.root_device - def teardown(self) -> None: """This method is called to teardown the training process. @@ -145,24 +118,6 @@ def teardown(self) -> None: """ self.training_type_plugin.teardown() - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: - """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just - having all tensors on the correct device. 
- - Args: - batch: The batch of samples to move to the correct device - device: The target device - dataloader_idx: The index of the dataloader to which the batch belongs. - """ - model = self.lightning_module - device = device or self.root_device - - if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin): - # no need to transfer batch to device in DP mode - return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) - - return move_data_to_device(batch, device) - def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual training step. @@ -195,121 +150,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) - def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: - """Forwards backward-calls to the precision plugin. - - Args: - closure_loss: a tensor holding the loss value to backpropagate - """ - self.training_type_plugin.pre_backward(closure_loss) - closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss) - - self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - - closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss) - self.training_type_plugin.post_backward(closure_loss) - - return closure_loss - - def optimizer_step( - self, - optimizer: Optimizer, - opt_idx: int, - closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, - **kwargs: Any, - ) -> None: - """performs the actual optimizer step. - - Args: - optimizer: the optimizer performing the step - opt_idx: index of the current optimizer - closure: closure calculating the loss value - model: reference to the model, optionally defining optimizer step related hooks - **kwargs: Any extra arguments to ``optimizer.step`` - """ - model = model or self.lightning_module - self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) - - def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: - """Zeros all model parameter's gradients.""" - model_ref = self.lightning_module - model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) - - def setup_optimizers(self, trainer: "pl.Trainer") -> None: - """Creates optimizers and schedulers. 
- - Args: - trainer: the Trainer, these optimizers should be connected to - """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): - return - optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( - trainer=trainer, model=self.lightning_module - ) - self.optimizers = optimizers - self.lr_schedulers = lr_schedulers - self.optimizer_frequencies = optimizer_frequencies - - def setup_training_type_plugin(self) -> None: - """Attaches the training type plugin to the accelerator.""" - self.training_type_plugin.setup() - - def setup_precision_plugin(self) -> None: - """Attaches the precision plugin to the accelerator.""" - model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect( - self.model, self.optimizers, self.lr_schedulers - ) - self.model = model - self.optimizers = optimizers - self.lr_schedulers = schedulers - - @property - def amp_backend(self) -> Optional[LightningEnum]: - if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin): - return AMPType.APEX - if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin): - return AMPType.NATIVE - return None - - @property - def precision(self) -> Union[str, int]: - """The type of precision being used with this accelerator. - - .. deprecated:: - This property been deprecated and will be removed soon. - Use ``training_type_plugin.precision_plugin.precision`` instead. - """ - rank_zero_deprecation( - f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" - f" Use `training_type_plugin.precision_plugin.precision` instead." - ) - return self.training_type_plugin.precision_plugin.precision - - @property - def scaler(self) -> Optional["GradScaler"]: - return getattr(self.training_type_plugin.precision_plugin, "scaler", None) - - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: - """Returns state of an optimizer. - - Allows for syncing/collating optimizer state from processes in custom plugins. - """ - return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer) - - @contextlib.contextmanager - def model_sharded_context(self) -> Generator[None, None, None]: - """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. - - shard the model instantly - useful for extremely large models. Can save memory and - initialization time. - - Returns: - Model parallel context. - """ - with self.training_type_plugin.model_sharded_context(): - yield - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """Gets stats for a given device. diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 8b18676effb79..7d5786102d0b3 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -29,8 +29,10 @@ def setup(self, trainer: "pl.Trainer") -> None: MisconfigurationException: If the selected device is not CPU. """ - if "cpu" not in str(self.root_device): - raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.") + if "cpu" not in str(self.training_type_plugin.root_device): + raise MisconfigurationException( + f"Device should be CPU, got {self.training_type_plugin.root_device} instead." 
+ ) return super().setup(trainer) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 62af5f27dcc1c..d72178ffd40b6 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -37,12 +37,14 @@ def setup_environment(self) -> None: If the selected device is not GPU. """ super().setup_environment() - if "cuda" not in str(self.root_device): - raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") - torch.cuda.set_device(self.root_device) + if "cuda" not in str(self.training_type_plugin.root_device): + raise MisconfigurationException( + f"Device should be GPU, got {self.training_type_plugin.root_device} instead" + ) + torch.cuda.set_device(self.training_type_plugin.root_device) def setup(self, trainer: "pl.Trainer") -> None: - self.set_nvidia_flags(trainer.local_rank) + self.set_nvidia_flags(getattr(self.training_type_plugin, "local_rank", 0)) return super().setup(trainer) def on_train_start(self) -> None: @@ -77,7 +79,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: def teardown(self) -> None: super().teardown() - self._move_optimizer_state(torch.device("cpu")) + self.training_type_plugin._move_optimizer_state(torch.device("cpu")) @staticmethod def auto_device_count() -> int: diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 0f6bdb8270395..155dce5275a9b 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -15,25 +15,12 @@ import torch -import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities.exceptions import MisconfigurationException class IPUAccelerator(Accelerator): """Accelerator for IPUs.""" - def setup_optimizers(self, trainer: "pl.Trainer") -> None: - """ - Raises: - MisconfigurationException: - If multiple optimizers are provided. - """ - super().setup_optimizers(trainer) - - if len(self.optimizers) > 1: - raise MisconfigurationException("IPUs currently only support one optimizer.") - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """IPU device stats aren't supported yet.""" return {} diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 673e8419ca7fb..f116ed7f0f493 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union import torch @@ -21,7 +21,6 @@ from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.utilities import _XLA_AVAILABLE -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm @@ -49,14 +48,6 @@ def setup(self, trainer: "pl.Trainer") -> None: ) return super().setup(trainer) - def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: - """Moves the state of the optimizers to the TPU if needed.""" - # TODO: `self.root_device` would raise error if called outside the spawn process - # while training on 8 and more cores. 
- for opt in self.optimizers: - for p, v in opt.state.items(): - opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """Gets stats for the given TPU device. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 89f46949a525c..1c1654e763452 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1360,7 +1360,7 @@ def training_step(...): **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward` """ self._verify_is_manual_optimization("manual_backward") - self.trainer.accelerator.backward(loss, None, None, *args, **kwargs) + self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs) def backward( self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index ecd62ab81715e..b3f49d393824f 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -161,4 +161,4 @@ def closure_dis(): trainer = self._trainer assert trainer is not None with trainer.profiler.profile(profiler_action): - trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 4997d7db779e7..bbdd4fb5fab45 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -112,7 +112,7 @@ def device(self) -> torch.device: Use this to create tensors directly on the device if needed. """ - return self._accelerator.root_device + return self._strategy.root_device @property def global_rank(self) -> int: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 3cd2f5eb69712..719cc7af29b87 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -50,6 +50,8 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._accelerator = accelerator + # TODO refactor to take Strategy as param, API breaking change for Lite? 
@ + self._strategy = self._accelerator.training_type_plugin @property def optimizer(self) -> Optimizer: @@ -57,11 +59,11 @@ def optimizer(self) -> Optimizer: def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure - self._accelerator.optimizer_step( + self._strategy.optimizer_step( self.optimizer, opt_idx=0, closure=closure, - model=self._accelerator.model, + model=self._strategy.model, ) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 102603f20302b..e802c943b56ed 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -108,7 +108,7 @@ def advance( if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device: with self.trainer.profiler.profile("evaluation_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 58e65233dfe81..558b1052c4e50 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -92,7 +92,7 @@ def advance( raise StopIteration with self.trainer.profiler.profile("predict_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8ddca3ad505e8..81085ca13eb18 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -157,7 +157,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: if not self.trainer._data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch) + batch = self.trainer.training_type_plugin.batch_to_device(batch) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 7050ac75de8eb..e4a42bebb3eed 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -321,7 +321,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call return None def backward_fn(loss: Tensor) -> None: - self.trainer.accelerator.backward(loss, optimizer, opt_idx) + self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx) # check if model weights are nan if self.trainer._terminate_on_nan: @@ -403,7 +403,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, optimizer: the current optimizer opt_idx: the index of the current optimizer """ - self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) self.optim_progress.optimizer.zero_grad.increment_completed() def _training_step(self, split_batch: 
Any, batch_idx: int, opt_idx: int) -> ClosureResult: diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index a0bbb4b9211ac..13b95be3443fc 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PARAMETERS @@ -47,13 +48,17 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def dispatch(self, trainer: "pl.Trainer") -> None: if not self._connected: - accelerator = trainer.accelerator - _, accelerator.optimizers = amp.initialize( - trainer.lightning_module, accelerator.optimizers, opt_level=self.amp_level + strategy = trainer.training_type_plugin + _, strategy.optimizers = amp.initialize( + trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level ) self._connected = True return super().dispatch(trainer) + @property + def amp_backend(self) -> Optional[LightningEnum]: + return AMPType.APEX + def backward( self, model: "pl.LightningModule", diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index fe4a840b5337c..0867ae9d6590e 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -16,12 +16,14 @@ import torch from torch import Tensor +from torch.cuda.amp import GradScaler from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: @@ -55,7 +57,15 @@ def __init__( raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") self.precision = precision self.device = device - self.scaler = scaler + self._scaler = scaler + + @property + def scaler(self) -> Optional["GradScaler"]: + return self._scaler + + @property + def amp_backend(self) -> Optional[LightningEnum]: + return AMPType.NATIVE def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor: if self.scaler is not None: diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index c4969c9cc805f..af3feee9e04ee 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -17,12 +17,14 @@ import torch from torch import Tensor +from torch.cuda.amp import GradScaler from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.types import _PARAMETERS @@ -48,6 +50,14 @@ def connect( """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers + 
@property + def scaler(self) -> Optional["GradScaler"]: + return None + + @property + def amp_backend(self) -> Optional[LightningEnum]: + return None + def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: """Run before precision plugin executes backward. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index da724944ade7e..b958d4808e1c2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -125,11 +125,12 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self): return True - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() + super().setup(trainer) def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 01959bdcee212..228233ac67854 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -517,7 +517,7 @@ def _initialize_deepspeed_train(self, model): self.model = model @contextlib.contextmanager - def model_sharded_context(self) -> Generator[None, None, None]: + def model_sharded_context(self) -> Generator: if self.zero_stage_3: assert self._config_initialized dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 3f1b9a3acfa50..ad4dd1497db61 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Any, List, Optional import torch from torch.nn import DataParallel, Module +import pytorch_lightning as pl from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -61,10 +62,11 @@ def node_rank(self) -> int: def world_size(self) -> int: return 1 - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() self._model = self._setup_model(LightningParallelModule(self._model)) + super().setup(trainer) def _setup_model(self, model: Module) -> DataParallel: """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module.""" @@ -95,6 +97,18 @@ def root_device(self): def model_to_device(self) -> None: self._model.to(self.root_device) + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just + having all tensors on the correct device. 
+ + Args: + batch: The batch of samples to move to the correct device + device: The target device + dataloader_idx: The index of the dataloader to which the batch belongs. + """ + model = self.lightning_module + return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) + def barrier(self, *args, **kwargs): pass diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 73ea87b05835e..ab60f7a44665a 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -157,7 +157,7 @@ def configure_ddp(self) -> None: self.model_to_device() # setup optimizers after fully sharded has wrapped the lightning module - self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) + self.setup_optimizers(self.lightning_module.trainer) def pre_dispatch(self) -> None: if self.sync_batchnorm: diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 961d2764b8ef3..4aef238abb5db 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -19,6 +19,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -73,8 +74,9 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() + super().setup(trainer) def pre_dispatch(self): @@ -85,7 +87,7 @@ def pre_dispatch(self): def _unpack_lightning_optimizer(opt): return opt._optimizer if isinstance(opt, LightningOptimizer) else opt - optimizers = self.lightning_module.trainer.optimizers + optimizers = self.optimizers optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers] # Horovod: scale the learning rate by the number of workers to account for @@ -106,7 +108,7 @@ def _unpack_lightning_optimizer(opt): for optimizer in optimizers: hvd.broadcast_optimizer_state(optimizer, root_rank=0) - self.lightning_module.trainer.accelerator.optimizers = self._wrap_optimizers(optimizers) + self.optimizers = self._wrap_optimizers(optimizers) def start_training(self, trainer): with ExitStack() as stack: diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index ef9b3d1f02b82..e86bd404fa429 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -110,7 +110,7 @@ def __init__( options["autoReport.directory"] = self.autoreport_dir os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # set the `accumulate_grad_batches` property as early as possible self._handle_gradient_accumulation_steps() @@ -119,6 +119,15 @@ def setup(self) -> None: # to use the simpler solution before adding abstractions to override the `DataLoader` class self._update_dataloader_original = pl.trainer.data_loading._update_dataloader pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader + 
super().setup(trainer) + + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + # refactor after move accelerator into strategy @four4fish + # RFC: I think set optimizer related logic should be in strategy instead of accelerator. + if len(self.optimizers) > 1: + raise MisconfigurationException("IPUs currently only support one optimizer.") + + super().setup_optimizers(trainer) def pre_dispatch(self) -> None: model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) @@ -314,7 +323,7 @@ def on_predict_end(self): def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: # Updates optimizer stats if LR scheduler modified the optimizer state - optimizer = self.lightning_module.trainer.optimizers[0] + optimizer = self.optimizers[0] self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer) @property diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 12a0f625b64fc..9dde35a589e05 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,6 +15,7 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin @@ -69,8 +70,9 @@ def root_device(self) -> torch.device: def model_to_device(self) -> None: self._model.to(self.root_device) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() + super().setup(trainer) @property def is_global_zero(self) -> bool: diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index e6f6a5f4b26f2..c987c9c732c66 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -14,11 +14,15 @@ import os from typing import Any, Dict, Optional +import torch + +import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _PATH @@ -50,7 +54,7 @@ def __init__( def is_distributed(self) -> bool: return False - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: shared_params = find_shared_parameters(self.model) self.model_to_device() if is_overridden("on_post_move_to_device", self.lightning_module): @@ -58,6 +62,16 @@ def setup(self) -> None: else: set_shared_parameters(self.model, shared_params) + super().setup(trainer) + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the TPU if needed.""" + # TODO: `self.root_device` would raise error if called outside the spawn process + # while training on 8 and more cores. 
+ for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) + def model_to_device(self) -> None: self.model.to(self.root_device) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 3ab9a8171aac5..f47d766d52cb8 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -32,7 +32,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters -from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -121,8 +121,17 @@ def pre_dispatch(self): if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.create_mp_queue() + super().setup(trainer) + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the TPU if needed.""" + # TODO: `self.root_device` would raise error if called outside the spawn process + # while training on 8 and more cores. + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) def _setup_model(self, model: Module) -> Module: return model @@ -170,8 +179,8 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: else: set_shared_parameters(self.model.module, shared_params) - trainer.accelerator.setup_optimizers(trainer) - self.precision_plugin.connect(self._model, None, None) + trainer.training_type_plugin.setup_optimizers(trainer) + trainer.precision_plugin.connect(self._model, None, None) self.barrier("pre-run-stage") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 7010c0e878dc9..59ee36ae4ae93 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. 
import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor @@ -26,6 +26,8 @@ from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -42,6 +44,9 @@ def __init__( checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() + self.optimizers: List = [] + self.lr_schedulers: List = [] + self.optimizer_frequencies: List = [] @property def checkpoint_io(self) -> CheckpointIO: @@ -66,8 +71,105 @@ def setup_environment(self) -> None: environment before setup is complete. """ - def setup(self) -> None: - """Called by the accelerator to finish setup.""" + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + """Creates optimizers and schedulers. + + Args: + trainer: the Trainer, these optimizers should be connected to + """ + if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + return + optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers( + trainer=trainer, model=self.lightning_module + ) + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.optimizer_frequencies = optimizer_frequencies + + def setup(self, trainer: "pl.Trainer") -> None: + """Setup plugins for the trainer fit and creates optimizers. + + Args: + trainer: the trainer instance + """ + # call super() + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + + def setup_precision_plugin(self) -> None: + """Attaches the precision plugin to the accelerator.""" + model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) + self.model = model + self.optimizers = optimizers + self.lr_schedulers = schedulers + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the GPU if needed.""" + device = device or self.root_device + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device) + + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + """Returns state of an optimizer. + + Allows for syncing/collating optimizer state from processes in custom plugins. + """ + return optimizer.state_dict() + + def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: + """Forwards backward-calls to the precision plugin. 
+ + Args: + closure_loss: a tensor holding the loss value to backpropagate + """ + self.pre_backward(closure_loss) + closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + + self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + + closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + self.post_backward(closure_loss) + + return closure_loss + + def optimizer_step( + self, + optimizer: Optimizer, + opt_idx: int, + closure: Callable[[], Any], + model: Optional[Union["pl.LightningModule", Module]] = None, + **kwargs: Any, + ) -> None: + """performs the actual optimizer step. + + Args: + optimizer: the optimizer performing the step + opt_idx: index of the current optimizer + closure: closure calculating the loss value + model: reference to the model, optionally defining optimizer step related hooks + **kwargs: Any extra arguments to ``optimizer.step`` + """ + model = model or self.lightning_module + self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + + def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: + """Zeros all model parameter's gradients.""" + model_ref = self.lightning_module + model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) + + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just + having all tensors on the correct device. + + Args: + batch: The batch of samples to move to the correct device + device: The target device + dataloader_idx: The index of the dataloader to which the batch belongs. + """ + device = device or self.root_device + return move_data_to_device(batch, device) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. 
@@ -202,7 +304,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: optimizer_states = checkpoint["optimizer_states"] - for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): + for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) def start_training(self, trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7136437bbc69d..cb0e4af502a5f 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -584,7 +584,7 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: @property def root_gpu(self) -> Optional[int]: return ( - self.accelerator.root_device.index + self.training_type_plugin.root_device.index if not isinstance(self.accelerator, (IPUAccelerator, TPUAccelerator)) else None ) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab0d3aa4288fa..92cad3b118006 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -382,7 +382,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: optimizer_states = [] for i, optimizer in enumerate(self.trainer.optimizers): # Rely on accelerator to dump optimizer state - optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) + optimizer_state = self.trainer.training_type_plugin.optimizer_state(optimizer) optimizer_states.append(optimizer_state) checkpoint["optimizer_states"] = optimizer_states diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index de81060ba1f80..080d3d94402f4 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -116,7 +116,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) data_fetcher.setup( dataloader, stage=stage, - batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), + batch_to_device=partial(self.trainer.training_type_plugin.batch_to_device, dataloader_idx=dataloader_idx), profiler=self.trainer.profiler, ) setattr(self, f"{stage}_data_fetcher", data_fetcher) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1ccdb9ecaeca8..c4e1d1a551567 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1393,7 +1393,7 @@ def _call_setup_hook(self) -> None: self.training_type_plugin.barrier("post_setup") def _call_configure_sharded_model(self) -> None: - with self.accelerator.model_sharded_context(): + with self.training_type_plugin.model_sharded_context(): self._handle_meta_model() self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") @@ -1635,7 +1635,7 @@ def lightning_module(self) -> "pl.LightningModule": @property def optimizers(self) -> List[Optimizer]: - return self.accelerator.optimizers + return self.training_type_plugin.optimizers @optimizers.setter def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: @@ -1644,27 +1644,27 @@ def optimizers(self, new_optims: 
Optional[List[Optimizer]]) -> None: # the `lightning_optimizers` trainer property self._lightning_optimizers = None - self.accelerator.optimizers = new_optims + self.training_type_plugin.optimizers = new_optims @property def lr_schedulers(self) -> List[LRSchedulerTypeUnion]: - return self.accelerator.lr_schedulers + return self.training_type_plugin.lr_schedulers @lr_schedulers.setter def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None: - self.accelerator.lr_schedulers = new_schedulers + self.training_type_plugin.lr_schedulers = new_schedulers @property def optimizer_frequencies(self) -> list: - return self.accelerator.optimizer_frequencies + return self.training_type_plugin.optimizer_frequencies @optimizer_frequencies.setter def optimizer_frequencies(self, new_freqs: list) -> None: - self.accelerator.optimizer_frequencies = new_freqs + self.training_type_plugin.optimizer_frequencies = new_freqs @property def amp_backend(self) -> Optional[str]: - return self.accelerator.amp_backend + return self.precision_plugin.amp_backend @property def precision(self) -> Union[str, int]: @@ -1672,7 +1672,7 @@ def precision(self) -> Union[str, int]: @property def scaler(self): - return self.accelerator.scaler + return self.precision_plugin.scaler @property def gpus(self) -> Optional[Union[List[int], str, int]]: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index d30edb177ed10..83f769813cdc6 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -21,9 +21,9 @@ from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -101,7 +101,7 @@ def on_train_end(self, trainer, pl_module): if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. 
the batchnorm update epoch should not backward - assert trainer.accelerator.backward.call_count == trainer.max_epochs * trainer.limit_train_batches + assert trainer.training_type_plugin.backward.call_count == trainer.max_epochs * trainer.limit_train_batches # check call counts assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) @@ -131,7 +131,7 @@ def train_with_swa( num_processes=num_processes, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward): trainer.fit(model) # check the model is the expected diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index d35941ac2cb15..da2aa58e2f041 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -306,7 +306,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): model.transfer_batch_to_device = dm.transfer_batch_to_device model.on_after_batch_transfer = dm.on_after_batch_transfer - batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) + batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) assert dm.on_before_batch_transfer_hook_rank == 0 assert dm.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index c271d3b3163ed..8cda56485b9d8 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -17,6 +17,7 @@ import torch from torch.utils.data.dataloader import DataLoader +from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -145,8 +146,10 @@ def test_lite_optimizer_wraps(): def test_lite_optimizer_steps(): """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() - accelerator = Mock() + strategy = Mock() + accelerator = Accelerator(strategy) lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) lite_optimizer.step() - accelerator.optimizer_step.assert_called_once() - accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) + strategy = accelerator.training_type_plugin + strategy.optimizer_step.assert_called_once() + strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index bf9dab47a71aa..9e0e67200c38f 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -252,35 +252,35 @@ def test_single_gpu_batch_parse(): # non-transferrable types primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] for batch in primitive_objects: - data = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + data = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert data == batch # batch is just a tensor batch = torch.rand(2, 3) - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch.device.index == 0 and batch.type() == "torch.cuda.FloatTensor" # tensor list batch = [torch.rand(2, 3), torch.rand(2, 3)] - batch = 
trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].device.index == 0 and batch[0].type() == "torch.cuda.FloatTensor" assert batch[1].device.index == 0 and batch[1].type() == "torch.cuda.FloatTensor" # tensor list of lists batch = [[torch.rand(2, 3), torch.rand(2, 3)]] - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.cuda.FloatTensor" # tensor dict batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}] - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.cuda.FloatTensor" assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.cuda.FloatTensor" # tuple of tensor list and list of tensor dict batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)} for _ in range(2)]) - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[1][0]["a"].device.index == 0 @@ -292,7 +292,7 @@ def test_single_gpu_batch_parse(): # namedtuple of tensor BatchType = namedtuple("BatchType", ["a", "b"]) batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].a.device.index == 0 assert batch[0].a.type() == "torch.cuda.FloatTensor" @@ -305,7 +305,7 @@ def to(self, *args, **kwargs): self.a = self.a.to(*args, **kwargs) return self - batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(CustomBatchType(), torch.device("cuda:0")) assert batch.a.type() == "torch.cuda.FloatTensor" # torchtext.data.Batch @@ -326,7 +326,7 @@ def to(self, *args, **kwargs): label_field.build_vocab(dataset) batch = Batch(data=examples, dataset=dataset) - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) assert batch.text.type() == "torch.cuda.LongTensor" assert batch.label.type() == "torch.cuda.LongTensor" @@ -339,7 +339,7 @@ def test_non_blocking(): batch = torch.zeros(2, 3) with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0), non_blocking=True) class BatchObject: @@ -348,5 +348,5 @@ def to(self, *args, **kwargs): batch = BatchObject() with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0)) diff --git 
a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 35b50acfcef4f..c2abe17d35298 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -157,7 +157,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): # running .fit() would require us to implement custom data loaders, we mock the model reference instead model_getter_mock.return_value = model - batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) + batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) assert model.on_before_batch_transfer_hook_rank == 0 assert model.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index ba4fe915fadb1..1ddae5ee3392f 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.plugins.training_type import TrainingTypePlugin from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -128,7 +128,7 @@ def on_train_end(self): ) scaler_step = scaler_step_patch.start() - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -162,7 +162,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -189,7 +189,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 assert set(trainer.logged_metrics) == {"a_step", "a_epoch"} @@ -212,7 +212,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): gpus=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -470,7 +470,7 @@ def log_grad_norm(self, grad_norm_dict): track_grad_norm=2, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -540,7 +540,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == 
limit_train_batches * 2 assert trainer.progress_bar_metrics["train_loss_step"] == model._losses[-1] @@ -596,7 +596,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2 From f7e1f87462eea436f6d84a7522f5ab5cf07df425 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 2021 14:12:35 -0800 Subject: [PATCH 02/20] correct batch_to_device logic --- pytorch_lightning/plugins/training_type/dp.py | 6 +++--- .../plugins/training_type/training_type_plugin.py | 3 ++- tests/lite/test_wrappers.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index ad4dd1497db61..314f4f442f50c 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _METRIC_COLLECTION @@ -106,8 +106,8 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat device: The target device dataloader_idx: The index of the dataloader to which the batch belongs. """ - model = self.lightning_module - return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) + device = device or self.root_device + return move_data_to_device(batch, device) def barrier(self, *args, **kwargs): pass diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 59ee36ae4ae93..b0e98043c8c1d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -168,8 +168,9 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat device: The target device dataloader_idx: The index of the dataloader to which the batch belongs. """ + model = self.lightning_module device = device or self.root_device - return move_data_to_device(batch, device) + return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. 
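Taken together, the two hunks above put the model-aware transfer on the base strategy and leave `DataParallelPlugin` with a plain device move. Below is a minimal, self-contained sketch of that split — `ToyStrategy`, `ToyDataParallelStrategy`, and the `transfer_batch` hook name are stand-ins for illustration, not the real Lightning API:

```python
from typing import Any, Optional

import torch


def move_data_to_device(batch: Any, device: torch.device) -> Any:
    # Simplified stand-in for pytorch_lightning.utilities.apply_func.move_data_to_device.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_data_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_data_to_device(v, device) for k, v in batch.items()}
    return batch


class ToyStrategy:
    """Base strategy: route the batch through the attached module's transfer hook."""

    def __init__(self, root_device: torch.device, model: Any = None) -> None:
        self.root_device = root_device
        self.lightning_module = model

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        device = device or self.root_device
        if self.lightning_module is not None:
            # In the real plugin this is model._apply_batch_transfer_handler(...).
            return self.lightning_module.transfer_batch(batch, device, dataloader_idx)
        return move_data_to_device(batch, device)


class ToyDataParallelStrategy(ToyStrategy):
    """DP-style strategy: only move tensors; scattering to replicas happens later."""

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        return move_data_to_device(batch, device or self.root_device)


batch = [{"a": torch.rand(2, 3)}, torch.rand(2, 3)]
moved = ToyDataParallelStrategy(torch.device("cpu")).batch_to_device(batch)
assert moved[1].device.type == "cpu"
```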
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 8cda56485b9d8..e14c1fa25bddf 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -147,7 +147,7 @@ def test_lite_optimizer_steps(): """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() strategy = Mock() - accelerator = Accelerator(strategy) + accelerator = Accelerator(None, strategy) lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) lite_optimizer.step() strategy = accelerator.training_type_plugin From b6e2ac7e0eb3a630716141dfbdd6fa68c1d0281f Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Fri, 19 Nov 2021 10:01:40 -0800 Subject: [PATCH 03/20] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 719cc7af29b87..ff0e126a07fa5 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -50,7 +50,7 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._accelerator = accelerator - # TODO refactor to take Strategy as param, API breaking change for Lite? @ + # TODO (@awaelchli) refactor to take Strategy as param self._strategy = self._accelerator.training_type_plugin @property From 110a7baf16921a7d2e9dfecfbc03c01c04b98602 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 2021 15:11:59 -0800 Subject: [PATCH 04/20] fix tpu/ipu setup overrides --- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/plugins/training_type/ipu.py | 7 ++++--- pytorch_lightning/plugins/training_type/single_tpu.py | 5 ++++- pytorch_lightning/plugins/training_type/tpu_spawn.py | 5 ++++- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index d72178ffd40b6..49d0770e54ff0 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -44,7 +44,7 @@ def setup_environment(self) -> None: torch.cuda.set_device(self.training_type_plugin.root_device) def setup(self, trainer: "pl.Trainer") -> None: - self.set_nvidia_flags(getattr(self.training_type_plugin, "local_rank", 0)) + self.set_nvidia_flags(trainer.local_rank) return super().setup(trainer) def on_train_start(self) -> None: diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 228233ac67854..01959bdcee212 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -517,7 +517,7 @@ def _initialize_deepspeed_train(self, model): self.model = model @contextlib.contextmanager - def model_sharded_context(self) -> Generator: + def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 
e86bd404fa429..03a672877e379 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -119,11 +119,12 @@ def setup(self, trainer: "pl.Trainer") -> None: # to use the simpler solution before adding abstractions to override the `DataLoader` class self._update_dataloader_original = pl.trainer.data_loading._update_dataloader pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader - super().setup(trainer) + + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() def setup_optimizers(self, trainer: "pl.Trainer") -> None: - # refactor after move accelerator into strategy @four4fish - # RFC: I think set optimizer related logic should be in strategy instead of accelerator. if len(self.optimizers) > 1: raise MisconfigurationException("IPUs currently only support one optimizer.") diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index c987c9c732c66..3860b20d2fe99 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -55,6 +55,7 @@ def is_distributed(self) -> bool: return False def setup(self, trainer: "pl.Trainer") -> None: + # Revisit strategy inheritance. shared_params = find_shared_parameters(self.model) self.model_to_device() if is_overridden("on_post_move_to_device", self.lightning_module): @@ -62,7 +63,9 @@ def setup(self, trainer: "pl.Trainer") -> None: else: set_shared_parameters(self.model, shared_params) - super().setup(trainer) + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the TPU if needed.""" diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index f47d766d52cb8..ef46b96e18c72 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -122,8 +122,11 @@ def pre_dispatch(self): os.environ["PT_XLA_DEBUG"] = str(1) def setup(self, trainer: "pl.Trainer") -> None: + # Revisit strategy inheritance self.create_mp_queue() - super().setup(trainer) + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the TPU if needed.""" From ac1f49cc4f65c55479034a6893ac1cbe749779c3 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 2021 15:50:10 -0800 Subject: [PATCH 05/20] remove batch_to_device change --- pytorch_lightning/accelerators/accelerator.py | 22 ++++++++++++++++++- .../loops/epoch/evaluation_epoch_loop.py | 2 +- .../loops/epoch/prediction_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 16 ++------------ .../training_type/training_type_plugin.py | 13 ----------- .../trainer/connectors/data_connector.py | 2 +- tests/core/test_datamodules.py | 2 +- tests/models/test_gpu.py | 22 +++++++++---------- tests/models/test_hooks.py | 2 +- 10 files changed, 40 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index fea7cd2e4d367..1bfb12547473c 100644 --- 
a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -19,7 +19,8 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type import TrainingTypePlugin +from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin +from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -118,6 +119,25 @@ def teardown(self) -> None: """ self.training_type_plugin.teardown() + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + """Moves the batch to the correct device. + + The returned batch is of the same type as the input batch, just + having all tensors on the correct device. + Args: + batch: The batch of samples to move to the correct device + device: The target device + dataloader_idx: The index of the dataloader to which the batch belongs. + """ + model = self.lightning_module + device = device or self.root_device + + if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin): + # no need to transfer batch to device in DP mode + return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) + + return move_data_to_device(batch, device) + def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: """The actual training step. diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index e802c943b56ed..102603f20302b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -108,7 +108,7 @@ def advance( if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device: with self.trainer.profiler.profile("evaluation_batch_to_device"): - batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx) + batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 558b1052c4e50..58e65233dfe81 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -92,7 +92,7 @@ def advance( raise StopIteration with self.trainer.profiler.profile("predict_batch_to_device"): - batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx) + batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 81085ca13eb18..8ddca3ad505e8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -157,7 +157,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: if not self.trainer._data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): - batch = self.trainer.training_type_plugin.batch_to_device(batch) + batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/plugins/training_type/dp.py 
b/pytorch_lightning/plugins/training_type/dp.py index 314f4f442f50c..423e79d9bd83f 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import List, Optional import torch from torch.nn import DataParallel, Module @@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _METRIC_COLLECTION @@ -97,18 +97,6 @@ def root_device(self): def model_to_device(self) -> None: self._model.to(self.root_device) - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: - """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just - having all tensors on the correct device. - - Args: - batch: The batch of samples to move to the correct device - device: The target device - dataloader_idx: The index of the dataloader to which the batch belongs. - """ - device = device or self.root_device - return move_data_to_device(batch, device) - def barrier(self, *args, **kwargs): pass diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b0e98043c8c1d..c32da4e76718b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -159,19 +159,6 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt model_ref = self.lightning_module model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: - """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just - having all tensors on the correct device. - - Args: - batch: The batch of samples to move to the correct device - device: The target device - dataloader_idx: The index of the dataloader to which the batch belongs. - """ - model = self.lightning_module - device = device or self.root_device - return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx) - def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. 
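The test updates in this commit (the `tests/models/test_gpu.py` hunks just below) assert device placement element by element for every container shape. A compact recursive helper in that spirit — hypothetical, not part of the test suite — could express the same check:

```python
from collections.abc import Mapping, Sequence
from typing import Any

import torch


def assert_on_device(batch: Any, device: torch.device) -> None:
    """Recursively check that every tensor in a (possibly nested) batch lives on `device`."""
    if isinstance(batch, torch.Tensor):
        assert batch.device.type == device.type
        if device.index is not None:
            assert batch.device.index == device.index
    elif isinstance(batch, Mapping):
        for value in batch.values():
            assert_on_device(value, device)
    elif isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        # Covers lists, tuples and namedtuples alike.
        for item in batch:
            assert_on_device(item, device)


# Same nested shapes the GPU tests exercise, checked on CPU here for portability.
assert_on_device(([torch.rand(2, 3)], {"a": torch.rand(2, 3)}), torch.device("cpu"))
```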
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 080d3d94402f4..de81060ba1f80 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -116,7 +116,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) data_fetcher.setup( dataloader, stage=stage, - batch_to_device=partial(self.trainer.training_type_plugin.batch_to_device, dataloader_idx=dataloader_idx), + batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), profiler=self.trainer.profiler, ) setattr(self, f"{stage}_data_fetcher", data_fetcher) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index da2aa58e2f041..d35941ac2cb15 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -306,7 +306,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): model.transfer_batch_to_device = dm.transfer_batch_to_device model.on_after_batch_transfer = dm.on_after_batch_transfer - batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) + batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) assert dm.on_before_batch_transfer_hook_rank == 0 assert dm.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 9e0e67200c38f..bf9dab47a71aa 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -252,35 +252,35 @@ def test_single_gpu_batch_parse(): # non-transferrable types primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] for batch in primitive_objects: - data = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + data = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert data == batch # batch is just a tensor batch = torch.rand(2, 3) - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch.device.index == 0 and batch.type() == "torch.cuda.FloatTensor" # tensor list batch = [torch.rand(2, 3), torch.rand(2, 3)] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].device.index == 0 and batch[0].type() == "torch.cuda.FloatTensor" assert batch[1].device.index == 0 and batch[1].type() == "torch.cuda.FloatTensor" # tensor list of lists batch = [[torch.rand(2, 3), torch.rand(2, 3)]] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.cuda.FloatTensor" # tensor dict batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.cuda.FloatTensor" assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.cuda.FloatTensor" # tuple of tensor list and list of tensor dict batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 
3), "b": torch.rand(2, 3)} for _ in range(2)]) - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[1][0]["a"].device.index == 0 @@ -292,7 +292,7 @@ def test_single_gpu_batch_parse(): # namedtuple of tensor BatchType = namedtuple("BatchType", ["a", "b"]) batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].a.device.index == 0 assert batch[0].a.type() == "torch.cuda.FloatTensor" @@ -305,7 +305,7 @@ def to(self, *args, **kwargs): self.a = self.a.to(*args, **kwargs) return self - batch = trainer.training_type_plugin.batch_to_device(CustomBatchType(), torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device("cuda:0")) assert batch.a.type() == "torch.cuda.FloatTensor" # torchtext.data.Batch @@ -326,7 +326,7 @@ def to(self, *args, **kwargs): label_field.build_vocab(dataset) batch = Batch(data=examples, dataset=dataset) - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) assert batch.text.type() == "torch.cuda.LongTensor" assert batch.label.type() == "torch.cuda.LongTensor" @@ -339,7 +339,7 @@ def test_non_blocking(): batch = torch.zeros(2, 3) with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0), non_blocking=True) class BatchObject: @@ -348,5 +348,5 @@ def to(self, *args, **kwargs): batch = BatchObject() with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0)) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index c2abe17d35298..35b50acfcef4f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -157,7 +157,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): # running .fit() would require us to implement custom data loaders, we mock the model reference instead model_getter_mock.return_value = model - batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) + batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) assert model.on_before_batch_transfer_hook_rank == 0 assert model.transfer_batch_to_device_hook_rank == 1 From 818b4e76200ebab9e1c73281e98cc25b8911a0e5 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 2021 16:01:43 -0800 Subject: [PATCH 06/20] remove batch_to_device change --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1bfb12547473c..b65012b7cc901 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -130,7 +130,7 @@ def batch_to_device(self, batch: Any, device: 
Optional[torch.device] = None, dat dataloader_idx: The index of the dataloader to which the batch belongs. """ model = self.lightning_module - device = device or self.root_device + device = device or self.training_type_plugin.root_device if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin): # no need to transfer batch to device in DP mode From 7b84fdea551a8e6e4ab21555f271efd9e4495285 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 2021 16:57:07 -0800 Subject: [PATCH 07/20] add changelog --- CHANGELOG.md | 6 ++---- pytorch_lightning/plugins/training_type/ipu.py | 4 ++-- pytorch_lightning/plugins/training_type/sharded_spawn.py | 2 +- .../plugins/training_type/training_type_plugin.py | 1 - 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index adb1b070dc386..dddc5ed55ab9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,10 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) -- - - -- +- Moved optimizer related logics from `Accelerator` to `TrainingTypePlugin` ([#10596](https://github.com/PyTorchLightning/pytorch-lightning/pull/10596)) +>>>>>>> 66d4cec7c (add changelog) - diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 03a672877e379..64a617dd66c20 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -125,11 +125,11 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() def setup_optimizers(self, trainer: "pl.Trainer") -> None: + super().setup_optimizers(trainer) + if len(self.optimizers) > 1: raise MisconfigurationException("IPUs currently only support one optimizer.") - super().setup_optimizers(trainer) - def pre_dispatch(self) -> None: model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model = model diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 12c06b9dde541..9f6b5f746ad2c 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -119,7 +119,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): - self.precision_plugin.scaler = ShardedGradScaler() + self._precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c32da4e76718b..b05d80795ddf0 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -92,7 +92,6 @@ def setup(self, trainer: "pl.Trainer") -> None: Args: trainer: the trainer instance """ - # call super() if not self.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) self.setup_precision_plugin() From ad8442bfa9039581b3021918a3101be7b0d5fb99 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 19 Nov 
2021 17:17:05 -0800 Subject: [PATCH 08/20] address comment about amp_backend --- pytorch_lightning/plugins/precision/apex_amp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 13b95be3443fc..f62f9d4f9638d 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -43,6 +43,10 @@ def __init__(self, amp_level: str = "O2") -> None: self.amp_level = amp_level self._connected = False + @property + def amp_backend(self) -> Optional[LightningEnum]: + return self.backend + def main_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) @@ -55,10 +59,6 @@ def dispatch(self, trainer: "pl.Trainer") -> None: self._connected = True return super().dispatch(trainer) - @property - def amp_backend(self) -> Optional[LightningEnum]: - return AMPType.APEX - def backward( self, model: "pl.LightningModule", From 3c2757e07d7185655bad896eea93d56214ccb216 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Mon, 22 Nov 2021 20:04:59 -0800 Subject: [PATCH 09/20] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Co-authored-by: Carlos Mocholí Co-authored-by: thomas chaton --- CHANGELOG.md | 1 - pytorch_lightning/accelerators/accelerator.py | 1 + pytorch_lightning/plugins/training_type/sharded_spawn.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dddc5ed55ab9a..89c204ca16223 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved optimizer related logics from `Accelerator` to `TrainingTypePlugin` ([#10596](https://github.com/PyTorchLightning/pytorch-lightning/pull/10596)) ->>>>>>> 66d4cec7c (add changelog) - diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index b65012b7cc901..8370ae42aa3a8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -124,6 +124,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat The returned batch is of the same type as the input batch, just having all tensors on the correct device. 
+ Args: batch: The batch of samples to move to the correct device device: The target device diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 9f6b5f746ad2c..91e49e3bdffd3 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -119,7 +119,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): - self._precision_plugin.scaler = ShardedGradScaler() + self._precision_plugin._scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod From e8f3773847b679f1be9801800d1b716bc71cae98 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Mon, 22 Nov 2021 20:05:40 -0800 Subject: [PATCH 10/20] Update native_amp.py --- pytorch_lightning/plugins/precision/native_amp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 0867ae9d6590e..6fa58370e6fb5 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -65,7 +65,7 @@ def scaler(self) -> Optional["GradScaler"]: @property def amp_backend(self) -> Optional[LightningEnum]: - return AMPType.NATIVE + return backend def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor: if self.scaler is not None: From bb98ac4422d8527a00ee132c5cf6bf0c56a805e5 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:47:31 -0800 Subject: [PATCH 11/20] Update native_amp.py --- pytorch_lightning/plugins/precision/native_amp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 6fa58370e6fb5..02f67e1db807b 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -65,7 +65,7 @@ def scaler(self) -> Optional["GradScaler"]: @property def amp_backend(self) -> Optional[LightningEnum]: - return backend + return self.backend def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor: if self.scaler is not None: From 6f6b66d113ea7fb568e11ab08a547f4c7bfca303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 14:33:04 +0100 Subject: [PATCH 12/20] add back Accelerator's root device --- pytorch_lightning/accelerators/accelerator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8370ae42aa3a8..7b02028b37999 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -112,6 +112,11 @@ def lightning_module(self) -> "pl.LightningModule": """ return self.training_type_plugin.lightning_module + @property + def root_device(self) -> torch.device: + """Returns the root device.""" + return self.training_type_plugin.root_device + def teardown(self) -> None: """This method is called to teardown the training process. 
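The hunk above restores `root_device` on the `Accelerator` as a pure pass-through to the strategy. A toy sketch of that delegation pattern (stand-in classes, not the real ones) shows why it keeps `trainer.accelerator`-based call sites working while the logic itself migrates into the plugin:

```python
import torch


class ToyStrategy:
    @property
    def root_device(self) -> torch.device:
        # The strategy is the single source of truth for device placement.
        return torch.device("cpu")


class ToyAccelerator:
    def __init__(self, strategy: ToyStrategy) -> None:
        self.training_type_plugin = strategy

    @property
    def root_device(self) -> torch.device:
        # Thin forwarder: no device logic lives on the accelerator anymore.
        return self.training_type_plugin.root_device


assert ToyAccelerator(ToyStrategy()).root_device.type == "cpu"
```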
From 6f3a8204ff7188b23a4f7444f9da4f2d5a6d3f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 14:35:14 +0100 Subject: [PATCH 13/20] remove a comment --- pytorch_lightning/plugins/training_type/single_tpu.py | 1 - pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 3860b20d2fe99..f9fa415e67090 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -55,7 +55,6 @@ def is_distributed(self) -> bool: return False def setup(self, trainer: "pl.Trainer") -> None: - # Revisit strategy inheritance. shared_params = find_shared_parameters(self.model) self.model_to_device() if is_overridden("on_post_move_to_device", self.lightning_module): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ef46b96e18c72..24d48c1aac459 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -122,7 +122,6 @@ def pre_dispatch(self): os.environ["PT_XLA_DEBUG"] = str(1) def setup(self, trainer: "pl.Trainer") -> None: - # Revisit strategy inheritance self.create_mp_queue() if not self.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) From 12bfa2702cff1fafd0d14bf33f711308b3f8c0db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 14:41:36 +0100 Subject: [PATCH 14/20] improve typing for attributes --- .../plugins/training_type/training_type_plugin.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b05d80795ddf0..96666719b3fb3 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -19,6 +19,7 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -44,9 +45,9 @@ def __init__( checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() - self.optimizers: List = [] - self.lr_schedulers: List = [] - self.optimizer_frequencies: List = [] + self.optimizers: List[Optimizer] = [] + self.lr_schedulers: List[_LRScheduler] = [] + self.optimizer_frequencies: List[int] = [] @property def checkpoint_io(self) -> CheckpointIO: From 95d11f323345ea3284167597b6341735be809617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 14:45:17 +0100 Subject: [PATCH 15/20] keep model shard context method in accelerator --- pytorch_lightning/accelerators/accelerator.py | 14 +++++++++++++- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7b02028b37999..6844e4a325810 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from abc import abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Generator import torch from torch.nn import Module @@ -176,6 +177,17 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) + @contextlib.contextmanager + def model_sharded_context(self) -> Generator[None, None, None]: + """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. + shard the model instantly - useful for extremely large models. Can save memory and + initialization time. + Returns: + Model parallel context. + """ + with self.training_type_plugin.model_sharded_context(): + yield + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """Gets stats for a given device. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c4e1d1a551567..90b356a55af96 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1393,7 +1393,7 @@ def _call_setup_hook(self) -> None: self.training_type_plugin.barrier("post_setup") def _call_configure_sharded_model(self) -> None: - with self.training_type_plugin.model_sharded_context(): + with self.accelerator.model_sharded_context(): self._handle_meta_model() self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") From 938dc5e426b7fd345fd93dd4d83b480ab5d4a3a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Nov 2021 13:46:41 +0000 Subject: [PATCH 16/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/accelerator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 6844e4a325810..e6a1a3006f805 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import abstractmethod -from typing import Any, Dict, Optional, Union, Generator +from typing import Any, Dict, Generator, Optional, Union import torch from torch.nn import Module @@ -180,6 +180,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. + shard the model instantly - useful for extremely large models. Can save memory and initialization time. 
Returns: From 8ce5ac6302a738a1274289da86be35192fbc5565 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 18:11:25 +0100 Subject: [PATCH 17/20] remove scaler and amp_backend properties from public Precision interface --- pytorch_lightning/plugins/precision/apex_amp.py | 5 ----- pytorch_lightning/plugins/precision/native_amp.py | 12 +----------- .../plugins/precision/precision_plugin.py | 9 --------- .../plugins/training_type/sharded_spawn.py | 2 +- pytorch_lightning/trainer/trainer.py | 6 +++--- 5 files changed, 5 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index f62f9d4f9638d..1e448a226a2a1 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -20,7 +20,6 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType -from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PARAMETERS @@ -43,10 +42,6 @@ def __init__(self, amp_level: str = "O2") -> None: self.amp_level = amp_level self._connected = False - @property - def amp_backend(self) -> Optional[LightningEnum]: - return self.backend - def main_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 32225119fdc3c..f6cb28c76c867 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -16,14 +16,12 @@ import torch from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType -from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: @@ -57,15 +55,7 @@ def __init__( raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") self.precision = precision self.device = device - self._scaler = scaler - - @property - def scaler(self) -> Optional["GradScaler"]: - return self._scaler - - @property - def amp_backend(self) -> Optional[LightningEnum]: - return self.backend + self.scaler = scaler def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor: if self.scaler is not None: diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 7e60341a3671d..140e4e3af05ee 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -17,7 +17,6 @@ import torch from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn import Module from torch.optim import Optimizer @@ -50,14 +49,6 @@ def connect( """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers - @property - def scaler(self) -> Optional["GradScaler"]: - return None - - @property - def amp_backend(self) -> 
Optional[LightningEnum]: - return None - def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: """Run before precision plugin executes backward. diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 91e49e3bdffd3..9f6b5f746ad2c 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -119,7 +119,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): - self._precision_plugin._scaler = ShardedGradScaler() + self._precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b573373b94090..f58f1d164c784 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1673,15 +1673,15 @@ def optimizer_frequencies(self, new_freqs: list) -> None: @property def amp_backend(self) -> Optional[str]: - return self.precision_plugin.amp_backend + return getattr(self.precision_plugin, "backend") @property def precision(self) -> Union[str, int]: return self.training_type_plugin.precision_plugin.precision @property - def scaler(self): - return self.precision_plugin.scaler + def scaler(self) -> Optional[Any]: + return getattr(self.precision_plugin, "scaler", None) @property def gpus(self) -> Optional[Union[List[int], str, int]]: From 15b6b4408847fc344a8c2b54a93c820635ec4a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 18:29:33 +0100 Subject: [PATCH 18/20] make amp_backend property equivalent to how it was in accelerator --- pytorch_lightning/trainer/trainer.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f58f1d164c784..c19f49e84726a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,15 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.plugins import ( + DDPSpawnPlugin, + ParallelPlugin, + PLUGIN_INPUT, + PrecisionPlugin, + TrainingTypePlugin, + ApexMixedPrecisionPlugin, + NativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -73,6 +81,7 @@ rank_zero_deprecation, rank_zero_info, rank_zero_warn, + AMPType, ) from pytorch_lightning.utilities.argparse import ( _defaults_from_env_vars, @@ -1672,8 +1681,12 @@ def optimizer_frequencies(self, new_freqs: list) -> None: self.training_type_plugin.optimizer_frequencies = new_freqs @property - def amp_backend(self) -> Optional[str]: - return getattr(self.precision_plugin, "backend") + def amp_backend(self) -> Optional[AMPType]: + if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + return AMPType.APEX + if isinstance(self.precision_plugin, 
NativeMixedPrecisionPlugin): + return AMPType.NATIVE + return None @property def precision(self) -> Union[str, int]: From f96c1e8babf0f86c92e9724a38e83cb892c8dddd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Nov 2021 17:30:58 +0000 Subject: [PATCH 19/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c19f49e84726a..26f03bef32db0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -39,13 +39,13 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import ( + ApexMixedPrecisionPlugin, DDPSpawnPlugin, + NativeMixedPrecisionPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin, - ApexMixedPrecisionPlugin, - NativeMixedPrecisionPlugin, ) from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( @@ -75,13 +75,13 @@ _IPU_AVAILABLE, _StrategyType, _TPU_AVAILABLE, + AMPType, device_parser, GradClipAlgorithmType, parsing, rank_zero_deprecation, rank_zero_info, rank_zero_warn, - AMPType, ) from pytorch_lightning.utilities.argparse import ( _defaults_from_env_vars, From 26d3a7dd301d1c4a0e79870cacb90379d18efe45 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Mon, 29 Nov 2021 11:15:28 -0800 Subject: [PATCH 20/20] remove unused import --- pytorch_lightning/plugins/precision/precision_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 140e4e3af05ee..3c02d198abd3c 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -23,7 +23,6 @@ import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType -from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.types import _PARAMETERS
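With the `scaler`/`amp_backend` properties gone from the base precision plugin, the trainer resolves them by inspecting the plugin instance instead (patches 17–18 above). The following is a self-contained sketch of that resolution pattern, using toy classes in place of the real precision plugins and `AMPType`:

```python
from enum import Enum
from typing import Any, Optional


class AMPKind(Enum):
    # Stand-in for pytorch_lightning.utilities.AMPType.
    NATIVE = "native"
    APEX = "apex"


class ToyPrecision:
    """Plain full-precision plugin: carries no scaler and no AMP backend."""


class ToyNativeAMP(ToyPrecision):
    def __init__(self) -> None:
        self.scaler = object()  # a torch.cuda.amp.GradScaler in the real plugin


class ToyApexAMP(ToyPrecision):
    pass


def amp_backend(plugin: ToyPrecision) -> Optional[AMPKind]:
    # Mirrors the trainer property: the plugin's type, not a property, decides the backend.
    if isinstance(plugin, ToyApexAMP):
        return AMPKind.APEX
    if isinstance(plugin, ToyNativeAMP):
        return AMPKind.NATIVE
    return None


def scaler(plugin: ToyPrecision) -> Optional[Any]:
    # Only AMP plugins carry a scaler attribute, so fall back to None via getattr.
    return getattr(plugin, "scaler", None)


assert amp_backend(ToyPrecision()) is None
assert amp_backend(ToyNativeAMP()) is AMPKind.NATIVE
assert scaler(ToyApexAMP()) is None
```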