From 8ffdb29ba4df0ab20369fd66c4f3a91b88b3de35 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 11:51:06 -0700 Subject: [PATCH 01/12] Refactor collective functions, call training_type_plugin directly --- pytorch_lightning/callbacks/timer.py | 2 +- pytorch_lightning/core/lightning.py | 2 +- .../training_type/training_type_plugin.py | 42 +++++++++++++++---- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/trainer.py | 10 ++--- tests/plugins/test_ddp_plugin.py | 4 +- tests/plugins/test_deepspeed_plugin.py | 4 +- .../test_deepspeed_collate_checkpoint.py | 4 +- 8 files changed, 49 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index ef7e586654225..efeedb30c42a8 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -165,7 +165,7 @@ def on_load_checkpoint( def _check_time_remaining(self, trainer: "pl.Trainer") -> None: should_stop = self.time_elapsed() >= self._duration - should_stop = trainer.accelerator.broadcast(should_stop) + should_stop = trainer.training_type_plugin.broadcast(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop and self._verbose: elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING))) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c241863605e6e..31514b69d9ae4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -594,7 +594,7 @@ def all_gather( the output will also be a collection with tensors of this shape. """ group = group if group is not None else torch.distributed.group.WORLD - all_gather = self.trainer.accelerator.all_gather + all_gather = self.trainer.training_type_plugin.all_gather data = convert_to_tensors(data, device=self.device) return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 13d6f93f5fb97..7212d26136cfc 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union import torch from torch import Tensor @@ -25,6 +25,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -91,26 +92,53 @@ def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" @abstractmethod - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces the given tensor (e.g. across GPUs/processes). 
Args: tensor: the tensor to sync and reduce + group: the process group to reduce + reduce_op: the reduction operation. Defaults to 'mean'. + Can also be a string 'sum' or ReduceOp. *args: plugin-specific positional arguments **kwargs: plugin-specific keyword arguments """ @abstractmethod def barrier(self, name: Optional[str] = None) -> None: - """Forces all possibly joined processes to wait for each other.""" + """Synchronizes all processes which blocks processes until the whole group enters this function. + + Args: + name: a str pass into barrier. Only torch xla respect this param + """ @abstractmethod - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - """Broadcasts an object to all processes.""" + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes. + + Args: + obj: the object to broadcast + src: source rank. + """ @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" + def all_gather( + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[List[torch.Tensor], torch.Tensor]: + """Perform a all_gather on all processes. + + Args: + tensor: the tensor to all_gather + group: the process group to gather results from + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Returns: a tensor (torch distributed) or a list of tensor (horovod) + """ def reduce_boolean_decision(self, decision: bool) -> bool: """Reduce the early stopping decision across all processes.""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 969404e68c498..61160093b7d08 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -531,7 +531,7 @@ def request_dataloader( dataloader = self.call_hook(hook, pl_module=model) if isinstance(dataloader, tuple): dataloader = list(dataloader) - self.accelerator.barrier("get_dataloaders") + self.accelerator.training_type_plugin.barrier("get_dataloaders") return dataloader @staticmethod diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 581ff11554cb3..424b2db520200 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -958,7 +958,7 @@ def _load_checkpoint_weights(self): # only one process running at this point for TPUs, as spawn isn't triggered yet # todo: move this logic internally within the barrier. 
if not self._device_type == DeviceType.TPU: - self.accelerator.barrier() + self.accelerator.training_type_plugin.barrier() rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) @@ -1141,7 +1141,7 @@ def run_stage(self): def _pre_training_routine(self): # wait for all to join if on distributed - self.accelerator.barrier("setup_training") + self.accelerator.training_type_plugin.barrier("setup_training") # register signals self.signal_connector.register_signal_handlers() @@ -1282,13 +1282,13 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn - self.accelerator.barrier("pre_setup") + self.accelerator.training_type_plugin.barrier("pre_setup") if self.datamodule is not None: self.datamodule.setup(stage=fn) self.call_hook("setup", stage=fn) - self.accelerator.barrier("post_setup") + self.accelerator.training_type_plugin.barrier("post_setup") def _call_configure_sharded_model(self) -> None: # Call configure sharded model hook if accelerator requests. In some cases @@ -1606,7 +1606,7 @@ def log_dir(self) -> Optional[str]: else: dirpath = self.logger.save_dir - dirpath = self.accelerator.broadcast(dirpath) + dirpath = self.accelerator.training_type_plugin.broadcast(dirpath) return dirpath @property diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index bd13275e9e5d1..03cc0e1ff7beb 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -57,11 +57,11 @@ def test_ddp_with_2_gpus(): class BarrierModel(BoringModel): def setup(self, stage=None): assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel) - self.trainer.accelerator.barrier("barrier before model is wrapped") + self.trainer.training_type_plugin.barrier("barrier before model is wrapped") def on_train_start(self): assert isinstance(self.trainer.accelerator.model, DistributedDataParallel) - self.trainer.accelerator.barrier("barrier after model is wrapped") + self.trainer.training_type_plugin.barrier("barrier after model is wrapped") @RunIf(min_gpus=4, special=True) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c7ccaab3e72f4..2a86a663b81f2 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -830,9 +830,9 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat def _assert_save_model_is_equal(model, tmpdir, trainer): checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) + checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - trainer.accelerator.barrier() + trainer.training_type_plugin.barrier() # carry out the check only on rank 0 if trainer.is_global_zero: diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index a04e56b7aabad..c60c85253abf4 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -31,9 +31,9 @@ def test_deepspeed_collate_checkpoint(tmpdir): ) trainer.fit(model) checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) + checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path) 
trainer.save_checkpoint(checkpoint_path) - trainer.accelerator.barrier() + trainer.training_type_plugin.barrier() if trainer.is_global_zero: # ensure function call works output_path = os.path.join(tmpdir, "single_model.pt") From 7141fb031bda2be38b588657c0bd07999e41bb74 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 11:54:00 -0700 Subject: [PATCH 02/12] Refactor collective functions, call training_type_plugin directly --- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/trainer.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 61160093b7d08..3a7c3bdc03836 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -531,7 +531,7 @@ def request_dataloader( dataloader = self.call_hook(hook, pl_module=model) if isinstance(dataloader, tuple): dataloader = list(dataloader) - self.accelerator.training_type_plugin.barrier("get_dataloaders") + self.training_type_plugin.barrier("get_dataloaders") return dataloader @staticmethod diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 424b2db520200..cd82af83d00cf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -958,7 +958,7 @@ def _load_checkpoint_weights(self): # only one process running at this point for TPUs, as spawn isn't triggered yet # todo: move this logic internally within the barrier. if not self._device_type == DeviceType.TPU: - self.accelerator.training_type_plugin.barrier() + self.training_type_plugin.barrier() rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) @@ -1141,7 +1141,7 @@ def run_stage(self): def _pre_training_routine(self): # wait for all to join if on distributed - self.accelerator.training_type_plugin.barrier("setup_training") + self.training_type_plugin.barrier("setup_training") # register signals self.signal_connector.register_signal_handlers() @@ -1282,13 +1282,13 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn - self.accelerator.training_type_plugin.barrier("pre_setup") + self.training_type_plugin.barrier("pre_setup") if self.datamodule is not None: self.datamodule.setup(stage=fn) self.call_hook("setup", stage=fn) - self.accelerator.training_type_plugin.barrier("post_setup") + self.training_type_plugin.barrier("post_setup") def _call_configure_sharded_model(self) -> None: # Call configure sharded model hook if accelerator requests. In some cases @@ -1606,7 +1606,7 @@ def log_dir(self) -> Optional[str]: else: dirpath = self.logger.save_dir - dirpath = self.accelerator.training_type_plugin.broadcast(dirpath) + dirpath = self.training_type_plugin.broadcast(dirpath) return dirpath @property From fb0d007c60d738f4267e139db2628f073cf938cc Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 12:01:58 -0700 Subject: [PATCH 03/12] Refactor collective functions, call training_type_plugin directly --- CHANGELOG.md | 7 +++++++ pytorch_lightning/accelerators/accelerator.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d85d89169a928..5065ef3661ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -224,6 +224,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787)) +- Directly call TrainingTypePlugin collective APIs instead of going through the Accelerator ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) + + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` @@ -265,6 +269,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616)) +- Deprecate Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call training type plugin collective API directly ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3036fd83ebf22..871c9930505c1 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -338,12 +338,21 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: return self.training_type_plugin.lightning_module_state_dict() def barrier(self, name: Optional[str] = None) -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.7. + Please call training_type_plugin.barrier directly + """ self.training_type_plugin.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.7. + Please call training_type_plugin.broadcast directly + Args: obj: Object to broadcast to all process, usually a tensor or collection of tensors. src: The source rank of which the object will be broadcast from @@ -353,6 +362,10 @@ def broadcast(self, obj: object, src: int = 0) -> object: def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.7. + Please call training_type_plugin.all_gather directly + Args: tensor: tensor of shape (batch, ...) group: the process group to gather results from. Defaults to all processes (world) From b91d4c9926b5f3c1db38185f66791a7509ffa32e Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Fri, 24 Sep 2021 14:05:03 -0700 Subject: [PATCH 04/12] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- .../plugins/training_type/training_type_plugin.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5065ef3661ae3..bffe7ef6e2d3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -269,7 +269,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616)) -- Deprecate Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call training type plugin collective API directly ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) +- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call training type plugin collective API directly ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) ### Removed diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 7212d26136cfc..5fb3a77778bd1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -114,7 +114,7 @@ def barrier(self, name: Optional[str] = None) -> None: """Synchronizes all processes which blocks processes until the whole group enters this function. Args: - name: a str pass into barrier. Only torch xla respect this param + name: an optional name to pass into barrier. """ @abstractmethod @@ -123,7 +123,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: Args: obj: the object to broadcast - src: source rank. + src: source rank """ @abstractmethod @@ -137,7 +137,8 @@ def all_gather( group: the process group to gather results from sync_grads: flag that allows users to synchronize gradients for all_gather op - Returns: a tensor (torch distributed) or a list of tensor (horovod) + Returns: + a tensor (torch distributed) or a list of tensor (horovod) """ def reduce_boolean_decision(self, decision: bool) -> bool: From 57c3efd86ffb8b9b8f6417dea8f42f4f1c54e820 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Sep 2021 21:06:11 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 5fb3a77778bd1..e48fe2a9459c4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -137,7 +137,7 @@ def all_gather( group: the process group to gather results from sync_grads: flag that allows users to synchronize gradients for all_gather op - Returns: + Returns: a tensor (torch distributed) or a list of tensor (horovod) """ From ed53c0d63c5600bc999030240fe0918edf470dda Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Fri, 24 Sep 2021 14:09:09 -0700 Subject: [PATCH 06/12] Apply suggestions from code review Co-authored-by: ananthsub --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 871c9930505c1..c580c93f6714f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -340,7 +340,7 @@ def lightning_module_state_dict(self) -> Dict[str, 
Union[Any, Tensor]]: def barrier(self, name: Optional[str] = None) -> None: """ .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.7. + This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.barrier directly """ self.training_type_plugin.barrier(name=name) @@ -350,7 +350,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: needed. .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.7. + This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.broadcast directly Args: @@ -363,7 +363,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo """Function to gather a tensor from several distributed processes. .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.7. + This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.all_gather directly Args: From 3d3e86e17d42bd84bbd06bd31c72efbc009b5eff Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 15:48:47 -0700 Subject: [PATCH 07/12] Refactor collective functions, call training_type_plugin directly --- pytorch_lightning/accelerators/accelerator.py | 14 +++++++++++++- .../plugins/training_type/training_type_plugin.py | 8 ++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c580c93f6714f..a3e529a1a34e0 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT @@ -343,6 +343,10 @@ def barrier(self, name: Optional[str] = None) -> None: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.barrier directly """ + rank_zero_deprecation( + "This method is deprecated in v1.5 and will be removed in v1.6." + "barrier logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + ) self.training_type_plugin.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: @@ -357,6 +361,10 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj: Object to broadcast to all process, usually a tensor or collection of tensors. src: The source rank of which the object will be broadcast from """ + rank_zero_deprecation( + "This method is deprecated in v1.5 and will be removed in v1.6." + "Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + ) return self.training_type_plugin.broadcast(obj, src) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -374,6 +382,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo Return: A tensor of shape (world_size, batch, ...) 
""" + rank_zero_deprecation( + "This method is deprecated in v1.5 and will be removed in v1.6." + "all_gather logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + ) return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e48fe2a9459c4..9bf1c9dd6ad31 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union import torch from torch import Tensor @@ -28,8 +28,6 @@ from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT -TBroadcast = TypeVar("T") - class TrainingTypePlugin(ABC): """Base class for all training type plugins that change the behaviour of the training, validation and test- @@ -127,9 +125,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: """ @abstractmethod - def all_gather( - self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[List[torch.Tensor], torch.Tensor]: + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: """Perform a all_gather on all processes. Args: From a64ed22f3eb6032b440ca6b5f0ae65070b9e3919 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 16:16:03 -0700 Subject: [PATCH 08/12] Refactor collective functions, call training_type_plugin directly --- .../training_type/training_type_plugin.py | 3 --- tests/deprecated_api/test_remove_1-6.py | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 9bf1c9dd6ad31..1b4b9e0b42fe7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -132,9 +132,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra tensor: the tensor to all_gather group: the process group to gather results from sync_grads: flag that allows users to synchronize gradients for all_gather op - - Returns: - a tensor (torch distributed) or a list of tensor (horovod) """ def reduce_boolean_decision(self, decision: bool) -> bool: diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index fec29ed6b47f8..9f37aab416e52 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Test deprecated functionality which will be removed in v1.6.0.""" import pytest +import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -318,3 +319,26 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import(): _soft_unimport_module("pytorch_lightning.utilities.device_dtype_mixin") with pytest.deprecated_call(match="will be removed in v1.6"): from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401 + + +def test_v1_7_0_deprecated_accelerator_collective(): + from pytorch_lightning.plugins.precision import PrecisionPlugin + from pytorch_lightning.plugins.training_type import SingleDevicePlugin + + plugin = SingleDevicePlugin(torch.device("cpu")) + from pytorch_lightning.accelerators.accelerator import Accelerator + + accelerator = Accelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.barrier() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.broadcast(1) + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand( + 2, + 2, + requires_grad=True, + ) + accelerator.all_gather(tensor) From abf0833cf0c8036944062d28a6cbe1aa04735654 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 24 Sep 2021 19:56:31 -0700 Subject: [PATCH 09/12] Refactor collective functions, call training_type_plugin directly --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a3e529a1a34e0..5925acbeca2ce 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -344,7 +344,7 @@ def barrier(self, name: Optional[str] = None) -> None: Please call training_type_plugin.barrier directly """ rank_zero_deprecation( - "This method is deprecated in v1.5 and will be removed in v1.6." + "Accelerator barrier is deprecated in v1.5 and will be removed in v1.6. " "barrier logic is implemented directly in the :class:`TrainingTypePlugin` implementations." ) self.training_type_plugin.barrier(name=name) @@ -362,7 +362,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: src: The source rank of which the object will be broadcast from """ rank_zero_deprecation( - "This method is deprecated in v1.5 and will be removed in v1.6." + "Accelerator broadcast is deprecated in v1.5 and will be removed in v1.6. " "Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations." ) return self.training_type_plugin.broadcast(obj, src) @@ -383,7 +383,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, batch, ...) """ rank_zero_deprecation( - "This method is deprecated in v1.5 and will be removed in v1.6." + "Accelerator all_gather is deprecated in v1.5 and will be removed in v1.6. " "all_gather logic is implemented directly in the :class:`TrainingTypePlugin` implementations." 
) return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) From 8f059818d68b3db4b44f7d28372e6c2a6ade79ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 27 Sep 2021 12:21:06 +0200 Subject: [PATCH 10/12] update formatting of warning messages --- pytorch_lightning/accelerators/accelerator.py | 14 +++++++------- tests/deprecated_api/test_remove_1-6.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4d9c10cb622e5..f7964d5c20530 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -341,11 +341,11 @@ def barrier(self, name: Optional[str] = None) -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call training_type_plugin.barrier directly + Please call ``training_type_plugin.barrier`` directly """ rank_zero_deprecation( - "Accelerator barrier is deprecated in v1.5 and will be removed in v1.6. " - "barrier logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " + "Barrier logic is implemented directly in the `TrainingTypePlugin` implementations." ) self.training_type_plugin.barrier(name=name) @@ -362,8 +362,8 @@ def broadcast(self, obj: object, src: int = 0) -> object: src: The source rank of which the object will be broadcast from """ rank_zero_deprecation( - "Accelerator broadcast is deprecated in v1.5 and will be removed in v1.6. " - "Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. " + "Broadcast logic is implemented directly in the `TrainingTypePlugin` implementations." ) return self.training_type_plugin.broadcast(obj, src) @@ -383,8 +383,8 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, batch, ...) """ rank_zero_deprecation( - "Accelerator all_gather is deprecated in v1.5 and will be removed in v1.6. " - "all_gather logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + "`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. " + "All-gather logic is implemented directly in the `TrainingTypePlugin` implementations." 
) return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 8582b052d4570..a83c288f1e414 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -348,6 +348,6 @@ def test_v1_7_0_deprecated_accelerator_collective(): tensor = torch.rand( 2, 2, - requires_grad=True, + requires_grad=True ) accelerator.all_gather(tensor) From b8c48e57806629888e6d0b44ceea9ebfc4bb2d4b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Sep 2021 10:22:38 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-6.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index a83c288f1e414..f580d9b89f7f8 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -345,9 +345,5 @@ def test_v1_7_0_deprecated_accelerator_collective(): accelerator.broadcast(1) with pytest.deprecated_call(match="will be removed in v1.6"): - tensor = torch.rand( - 2, - 2, - requires_grad=True - ) + tensor = torch.rand(2, 2, requires_grad=True) accelerator.all_gather(tensor) From ca29fc9161d897597749ce43e935847102873ec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 27 Sep 2021 12:56:23 +0200 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: Rohit Gupta --- CHANGELOG.md | 7 +++---- pytorch_lightning/accelerators/accelerator.py | 6 +++--- .../plugins/training_type/training_type_plugin.py | 4 +--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81fb4592b8779..73f6adabda3a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -233,10 +233,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787)) -- Directly call TrainingTypePlugin collective APIs instead of going through the Accelerator ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) +- Directly call `TrainingTypePlugin` collective APIs instead of going through the Accelerator ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677)) -- Use a unique filename to save temp ckpt in tuner ([#96827](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682)) +- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682)) ### Deprecated @@ -286,8 +286,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor in favor of adding the `StochasticWeightAveraging` callback directly to the list of callbacks ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989)) -- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call training type plugin collective API directly ([#9426](https://github.com/PyTorchLightning/pytorch-lightning/pull/9426)) - +- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677)) ### Removed diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index dc5955d9ac14a..89e52758f5f90 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -342,7 +342,7 @@ def barrier(self, name: Optional[str] = None) -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.barrier`` directly + Please call ``training_type_plugin.barrier`` directly. """ rank_zero_deprecation( "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " @@ -356,7 +356,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call training_type_plugin.broadcast directly + Please call ``training_type_plugin.broadcast`` directly. Args: obj: Object to broadcast to all process, usually a tensor or collection of tensors. @@ -373,7 +373,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call training_type_plugin.all_gather directly + Please call ``training_type_plugin.all_gather`` directly. Args: tensor: tensor of shape (batch, ...) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 09ea07cb87c04..6caaea8632354 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -102,8 +102,6 @@ def reduce( group: the process group to reduce reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp. - *args: plugin-specific positional arguments - **kwargs: plugin-specific keyword arguments """ @abstractmethod @@ -125,7 +123,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: @abstractmethod def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes. + """Perform an all_gather on all processes. Args: tensor: the tensor to all_gather
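
Note on the resulting call pattern: after this series, user code and callbacks call the collective helpers on `trainer.training_type_plugin` directly, while `trainer.accelerator.barrier` / `broadcast` / `all_gather` remain only as shims that emit a deprecation warning and delegate to the plugin. The following is a minimal, illustrative sketch of the migrated usage, not part of the patches themselves: the `CollectiveSyncCallback` class and the gathered payload are invented for illustration, and it assumes a pytorch-lightning build that already contains the changes above.

import torch
import pytorch_lightning as pl


class CollectiveSyncCallback(pl.Callback):
    """Toy callback showing the plugin-level collectives used throughout this series."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Before this series: trainer.accelerator.broadcast(...) / barrier() / all_gather().
        # Those still work in 1.5 but now emit a deprecation warning and simply forward
        # to the training type plugin, so call the plugin directly instead.
        stop = trainer.training_type_plugin.broadcast(trainer.should_stop, src=0)
        trainer.should_stop = trainer.should_stop or stop

        # Optional name, same as the Timer callback and Trainer hooks changed above.
        trainer.training_type_plugin.barrier("collective_sync_train_start")

        # On torch.distributed-based plugins this returns a tensor of shape
        # (world_size, *local.shape); on a single-device plugin it is a no-op.
        local = torch.tensor([float(trainer.global_rank)], device=pl_module.device)
        gathered = trainer.training_type_plugin.all_gather(local, sync_grads=False)
        if trainer.is_global_zero:
            print("ranks seen:", gathered.flatten().tolist())


# Usage (sketch): pl.Trainer(callbacks=[CollectiveSyncCallback()], max_epochs=1).fit(model)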