diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45389685645a9..73f6adabda3a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -233,7 +233,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))
 
 
-- Use a unique filename to save temp ckpt in tuner ([#96827](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
+- Directly call `TrainingTypePlugin` collective APIs instead of going through the Accelerator ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))
+
+
+- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
 
 
 ### Deprecated
@@ -283,6 +286,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor in favor of adding the `StochasticWeightAveraging` callback directly to the list of callbacks ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989))
 
 
+- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))
+
+
 ### Removed
 
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ae76fefb8db6e..89e52758f5f90 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -24,7 +24,7 @@
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
@@ -339,21 +339,42 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.lightning_module_state_dict()
 
     def barrier(self, name: Optional[str] = None) -> None:
+        """
+        .. deprecated:: v1.5
+            This method is deprecated in v1.5 and will be removed in v1.6.
+            Please call ``training_type_plugin.barrier`` directly.
+        """
+        rank_zero_deprecation(
+            "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. "
+            "Barrier logic is implemented directly in the `TrainingTypePlugin` implementations."
+        )
         self.training_type_plugin.barrier(name=name)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if
         needed.
 
+        .. deprecated:: v1.5
+            This method is deprecated in v1.5 and will be removed in v1.6.
+            Please call ``training_type_plugin.broadcast`` directly.
+
         Args:
             obj: Object to broadcast to all process, usually a tensor or collection of tensors.
             src: The source rank of which the object will be broadcast from
         """
+        rank_zero_deprecation(
+            "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. "
+            "Broadcast logic is implemented directly in the `TrainingTypePlugin` implementations."
+        )
         return self.training_type_plugin.broadcast(obj, src)
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """Function to gather a tensor from several distributed processes.
 
+        .. deprecated:: v1.5
+            This method is deprecated in v1.5 and will be removed in v1.6.
+            Please call ``training_type_plugin.all_gather`` directly.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
@@ -362,6 +383,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
         Return:
             A tensor of shape (world_size, batch, ...)
         """
+        rank_zero_deprecation(
+            "`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. "
+            "All-gather logic is implemented directly in the `TrainingTypePlugin` implementations."
+        )
         return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)
 
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py
index ef7e586654225..efeedb30c42a8 100644
--- a/pytorch_lightning/callbacks/timer.py
+++ b/pytorch_lightning/callbacks/timer.py
@@ -165,7 +165,7 @@ def on_load_checkpoint(
 
     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
         should_stop = self.time_elapsed() >= self._duration
-        should_stop = trainer.accelerator.broadcast(should_stop)
+        should_stop = trainer.training_type_plugin.broadcast(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop and self._verbose:
             elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index c241863605e6e..31514b69d9ae4 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -594,7 +594,7 @@ def all_gather(
             the output will also be a collection with tensors of this shape.
         """
         group = group if group is not None else torch.distributed.group.WORLD
-        all_gather = self.trainer.accelerator.all_gather
+        all_gather = self.trainer.training_type_plugin.all_gather
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)
 
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 675b5bc953503..6caaea8632354 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union
 
 import torch
 from torch import Tensor
@@ -25,10 +25,9 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
-TBroadcast = TypeVar("T")
-
 
 class TrainingTypePlugin(ABC):
     """Base class for all training type plugins that change the behaviour of the training, validation and test-
@@ -90,26 +89,47 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+    def reduce(
+        self,
+        tensor: Union[torch.Tensor, Any],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[torch.Tensor, Any]:
         """Reduces the given tensor (e.g. across GPUs/processes).
 
         Args:
             tensor: the tensor to sync and reduce
-            *args: plugin-specific positional arguments
-            **kwargs: plugin-specific keyword arguments
+            group: the process group to reduce
+            reduce_op: the reduction operation. Defaults to 'mean'.
+                Can also be a string 'sum' or ReduceOp.
         """
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None) -> None:
-        """Forces all possibly joined processes to wait for each other."""
+        """Synchronizes all processes which blocks processes until the whole group enters this function.
+
+        Args:
+            name: an optional name to pass into barrier.
+        """
 
     @abstractmethod
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        """Broadcasts an object to all processes."""
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        """Broadcasts an object to all processes.
+
+        Args:
+            obj: the object to broadcast
+            src: source rank
+        """
 
     @abstractmethod
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """Perform a all_gather on all processes."""
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        """
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
         """Reduce the early stopping decision across all processes."""
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 2fe9118ac5400..e46d7b1b1e7e8 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -525,7 +525,7 @@ def request_dataloader(
         dataloader = self.call_hook(hook, pl_module=model)
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
-        self.accelerator.barrier("get_dataloaders")
+        self.training_type_plugin.barrier("get_dataloaders")
         return dataloader
 
     @staticmethod
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 644d11bb75bd2..5e7ed9785dc33 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -965,7 +965,7 @@ def _load_checkpoint_weights(self):
         # only one process running at this point for TPUs, as spawn isn't triggered yet
         # todo: move this logic internally within the barrier.
         if not self._device_type == DeviceType.TPU:
-            self.accelerator.barrier()
+            self.training_type_plugin.barrier()
         rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
         self.checkpoint_connector.restore_model_weights(self._ckpt_path)
 
@@ -1148,7 +1148,7 @@ def run_stage(self):
 
     def _pre_training_routine(self):
         # wait for all to join if on distributed
-        self.accelerator.barrier("setup_training")
+        self.training_type_plugin.barrier("setup_training")
 
         # register signals
         self.signal_connector.register_signal_handlers()
@@ -1289,13 +1289,13 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
     def _call_setup_hook(self) -> None:
         fn = self.state.fn._setup_fn
 
-        self.accelerator.barrier("pre_setup")
+        self.training_type_plugin.barrier("pre_setup")
 
         if self.datamodule is not None:
             self.datamodule.setup(stage=fn)
         self.call_hook("setup", stage=fn)
 
-        self.accelerator.barrier("post_setup")
+        self.training_type_plugin.barrier("post_setup")
 
     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
@@ -1604,7 +1604,7 @@ def log_dir(self) -> Optional[str]:
         else:
             dirpath = self.logger.save_dir
 
-        dirpath = self.accelerator.broadcast(dirpath)
+        dirpath = self.training_type_plugin.broadcast(dirpath)
         return dirpath
 
     @property
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index 4028f4108feab..f580d9b89f7f8 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -15,6 +15,7 @@
 from unittest.mock import call, Mock
 
 import pytest
+import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
@@ -327,3 +328,22 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
     _soft_unimport_module("pytorch_lightning.utilities.device_dtype_mixin")
     with pytest.deprecated_call(match="will be removed in v1.6"):
         from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin  # noqa: F401
+
+
+def test_v1_7_0_deprecated_accelerator_collective():
+    from pytorch_lightning.plugins.precision import PrecisionPlugin
+    from pytorch_lightning.plugins.training_type import SingleDevicePlugin
+
+    plugin = SingleDevicePlugin(torch.device("cpu"))
+    from pytorch_lightning.accelerators.accelerator import Accelerator
+
+    accelerator = Accelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.barrier()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.broadcast(1)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        tensor = torch.rand(2, 2, requires_grad=True)
+        accelerator.all_gather(tensor)
diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index bd13275e9e5d1..03cc0e1ff7beb 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -57,11 +57,11 @@ def test_ddp_with_2_gpus():
 class BarrierModel(BoringModel):
     def setup(self, stage=None):
         assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel)
-        self.trainer.accelerator.barrier("barrier before model is wrapped")
+        self.trainer.training_type_plugin.barrier("barrier before model is wrapped")
 
     def on_train_start(self):
         assert isinstance(self.trainer.accelerator.model, DistributedDataParallel)
-        self.trainer.accelerator.barrier("barrier after model is wrapped")
+        self.trainer.training_type_plugin.barrier("barrier after model is wrapped")
 
 
 @RunIf(min_gpus=4, special=True)
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 68907096fe3d1..889a200b8a58f 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -830,9 +830,9 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat
 
 def _assert_save_model_is_equal(model, tmpdir, trainer):
     checkpoint_path = os.path.join(tmpdir, "model.pt")
-    checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
+    checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
-    trainer.accelerator.barrier()
+    trainer.training_type_plugin.barrier()
 
     # carry out the check only on rank 0
     if trainer.is_global_zero:
diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py
index a04e56b7aabad..c60c85253abf4 100644
--- a/tests/utilities/test_deepspeed_collate_checkpoint.py
+++ b/tests/utilities/test_deepspeed_collate_checkpoint.py
@@ -31,9 +31,9 @@ def test_deepspeed_collate_checkpoint(tmpdir):
     )
     trainer.fit(model)
     checkpoint_path = os.path.join(tmpdir, "model.pt")
-    checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
+    checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
-    trainer.accelerator.barrier()
+    trainer.training_type_plugin.barrier()
     if trainer.is_global_zero:
        # ensure function call works
        output_path = os.path.join(tmpdir, "single_model.pt")
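# --- Usage sketch (illustrative only, not part of the diff above) ---
# A minimal example of the migration this change performs: collective calls that
# previously went through `trainer.accelerator` are made on
# `trainer.training_type_plugin` directly. The `trainer` argument is assumed to be
# an already-configured pytorch_lightning.Trainer running under a distributed
# plugin; `sync_example` is a hypothetical helper, not part of the codebase.
import torch


def sync_example(trainer, path: str) -> torch.Tensor:
    # Deprecated in v1.5, removed in v1.6:
    #   trainer.accelerator.barrier("example")
    #   path = trainer.accelerator.broadcast(path, src=0)
    #   gathered = trainer.accelerator.all_gather(torch.ones(1))

    # Call the TrainingTypePlugin collective API directly instead:
    trainer.training_type_plugin.barrier("example")  # wait for all processes
    path = trainer.training_type_plugin.broadcast(path, src=0)  # rank 0's value on every rank
    gathered = trainer.training_type_plugin.all_gather(torch.ones(1))  # shape (world_size, 1)
    return gathered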