From 55c3bfce3ecc97f0c74854d3ee34151c2baf7fd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 19:45:36 +0100 Subject: [PATCH 1/9] add wrapper --- pytorch_lightning/overrides/data_parallel.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 393138fff9248..72d7bbf31d896 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -23,6 +23,7 @@ from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel._functions import Gather +from pytorch_lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.warning_utils import WarningCache @@ -151,6 +152,28 @@ def parallel_apply(self, replicas, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) +class LightningDistributedWrapper(torch.nn.Module): + + def __init__(self, lightning_module: LightningModule): + super().__init__() + self.module = lightning_module + + def forward(self, *inputs, **kwargs): + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + fx_called = 'training_step' + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + fx_called = 'test_step' + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + fx_called = 'validation_step' + + if output is None: + warn_missing_output(f'{fx_called} returned None. Did you forget to return an output') + return output + + class LightningDistributedDataParallel(DistributedDataParallel): """ Override the forward call in lightning so it goes to training and validation step respectively From 13b69ee29f6e849eba2ed2b4335ac5087e44f094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 19:55:33 +0100 Subject: [PATCH 2/9] add squeeze --- pytorch_lightning/overrides/data_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 72d7bbf31d896..ddf5e47400d49 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -168,9 +168,10 @@ def forward(self, *inputs, **kwargs): else: output = self.module.validation_step(*inputs[0], **kwargs[0]) fx_called = 'validation_step' - if output is None: - warn_missing_output(f'{fx_called} returned None. Did you forget to return an output') + warn_missing_output(f'{fx_called} returned None. 
Did you forget to return an output?') + elif self.module.use_dp or self.module.use_ddp2: + auto_squeeze_dim_zeros(output) return output From 4e53fe4ce2b4b6999cf0b9baeb498177b1690c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 20:01:47 +0100 Subject: [PATCH 3/9] replace LightningDistributedDP --- pytorch_lightning/overrides/data_parallel.py | 138 +++++++++--------- pytorch_lightning/plugins/ddp_plugin.py | 23 +-- .../plugins/ddp_sequential_plugin.py | 3 +- 3 files changed, 83 insertions(+), 81 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index ddf5e47400d49..a05a0e8a30c5a 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -174,75 +174,75 @@ def forward(self, *inputs, **kwargs): auto_squeeze_dim_zeros(output) return output - -class LightningDistributedDataParallel(DistributedDataParallel): - """ - Override the forward call in lightning so it goes to training and validation step respectively - """ - PREPARE_FOR_BACKWARDS = True - - def parallel_apply(self, replicas, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) - - def forward(self, *inputs, **kwargs): # pragma: no-cover - self._sync_params() - self.reducer_reset_hooks() - fx_called: str = '' - - if self.device_ids: - - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - if len(self.device_ids) == 1: - # -------------- - # LIGHTNING MOD - # -------------- - # normal - # output = self.module(*inputs[0], **kwargs[0]) - # lightning - if self.module.training: - output = self.module.training_step(*inputs[0], **kwargs[0]) - fx_called = 'training_step' - elif self.module.testing: - output = self.module.test_step(*inputs[0], **kwargs[0]) - fx_called = 'test_step' - else: - output = self.module.validation_step(*inputs[0], **kwargs[0]) - fx_called = 'validation_step' - else: - outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) - output = self.gather(outputs, self.output_device) - else: - # output = self.module(*inputs, **kwargs) - # normal lightning (ddp_cpu) - if self.module.training: - output = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: - output = self.module.test_step(*inputs, **kwargs) - else: - output = self.module.validation_step(*inputs, **kwargs) - - if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS: - self.reducer_prepare_for_backwards(output) - - if output is None: - warn_missing_output(f'{fx_called} returned None. Did you forget to return an output') - return output - - def reducer_prepare_for_backwards(self, output): - self._reducer_prepared_for_backwards = True - if torch.is_grad_enabled(): - # We'll return the output object verbatim since it is a freeform - # object. We need to find any tensors in this object, though, - # because we need to figure out which parameters were used during - # this forward pass, to ensure we short circuit reduction for any - # unused parameters. Only if `find_unused_parameters` is set. 
- if self.find_unused_parameters: - self.reducer.prepare_for_backward(list(_find_tensors(output))) - else: - self.reducer.prepare_for_backward([]) - - def reducer_reset_hooks(self): - self._reducer_prepared_for_backwards = False +# +# class LightningDistributedDataParallel(DistributedDataParallel): +# """ +# Override the forward call in lightning so it goes to training and validation step respectively +# """ +# PREPARE_FOR_BACKWARDS = True +# +# def parallel_apply(self, replicas, inputs, kwargs): +# return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) +# +# def forward(self, *inputs, **kwargs): # pragma: no-cover +# self._sync_params() +# self.reducer_reset_hooks() +# fx_called: str = '' +# +# if self.device_ids: +# +# inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) +# if len(self.device_ids) == 1: +# # -------------- +# # LIGHTNING MOD +# # -------------- +# # normal +# # output = self.module(*inputs[0], **kwargs[0]) +# # lightning +# if self.module.training: +# output = self.module.training_step(*inputs[0], **kwargs[0]) +# fx_called = 'training_step' +# elif self.module.testing: +# output = self.module.test_step(*inputs[0], **kwargs[0]) +# fx_called = 'test_step' +# else: +# output = self.module.validation_step(*inputs[0], **kwargs[0]) +# fx_called = 'validation_step' +# else: +# outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) +# output = self.gather(outputs, self.output_device) +# else: +# # output = self.module(*inputs, **kwargs) +# # normal lightning (ddp_cpu) +# if self.module.training: +# output = self.module.training_step(*inputs, **kwargs) +# elif self.module.testing: +# output = self.module.test_step(*inputs, **kwargs) +# else: +# output = self.module.validation_step(*inputs, **kwargs) +# +# if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS: +# self.reducer_prepare_for_backwards(output) +# +# if output is None: +# warn_missing_output(f'{fx_called} returned None. Did you forget to return an output') +# return output +# +# def reducer_prepare_for_backwards(self, output): +# self._reducer_prepared_for_backwards = True +# if torch.is_grad_enabled(): +# # We'll return the output object verbatim since it is a freeform +# # object. We need to find any tensors in this object, though, +# # because we need to figure out which parameters were used during +# # this forward pass, to ensure we short circuit reduction for any +# # unused parameters. Only if `find_unused_parameters` is set. 
+# if self.find_unused_parameters: +# self.reducer.prepare_for_backward(list(_find_tensors(output))) +# else: +# self.reducer.prepare_for_backward([]) +# +# def reducer_reset_hooks(self): +# self._reducer_prepared_for_backwards = False def warn_missing_output(fx_called): diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 281074cb37813..3dcf172636e10 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Union import torch.distributed as torch_distrib +from torch.nn.parallel import DistributedDataParallel from torch.optim import Optimizer from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.overrides.data_parallel import LightningDistributedWrapper from pytorch_lightning.plugins.plugin import LightningPlugin @@ -35,7 +36,7 @@ def __init__(self, **kwargs): def configure_ddp( self, model: LightningModule, device_ids: List[int] - ) -> LightningDistributedDataParallel: + ) -> DistributedDataParallel: """ Pass through all customizations from constructor to `LightningDistributedDataParallel`. Override to define a custom DDP implementation. @@ -63,8 +64,8 @@ def configure_ddp(self, model, device_ids): self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( "find_unused_parameters", True ) - model = LightningDistributedDataParallel( - model, + model = DistributedDataParallel( + module=LightningDistributedWrapper(model), device_ids=device_ids, **self._ddp_kwargs, ) @@ -120,7 +121,7 @@ def on_after_setup_optimizers(self, trainer): def get_model_from_plugin( self, - model: Union[LightningDistributedDataParallel, LightningModule] + model: Union[DistributedDataParallel, LightningModule] ) -> LightningModule: """ Override to modify returning base :class:`LightningModule` @@ -136,12 +137,14 @@ def get_model_from_plugin( Returns: Reference :class:`LightningModule` within parallel wrapper. """ - if isinstance(model, LightningDistributedDataParallel): - return model.module + if isinstance(model, DistributedDataParallel): + model = model.module + if isinstance(model, LightningDistributedWrapper): + model = model.module return model @contextmanager - def block_backward_sync(self, model: LightningDistributedDataParallel): + def block_backward_sync(self, model: DistributedDataParallel): """ Blocks ddp sync gradients behaviour on backwards pass. 
This is useful for skipping sync when accumulating gradients, reducing communication overhead @@ -149,10 +152,10 @@ def block_backward_sync(self, model: LightningDistributedDataParallel): """ yield model.no_sync() - def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any): + def on_before_manual_backward(self, model: DistributedDataParallel, output: Any): model.reducer_prepare_for_backwards(output) - def on_after_manual_backward(self, model: LightningDistributedDataParallel): + def on_after_manual_backward(self, model: DistributedDataParallel): model.reducer_reset_hooks() def distributed_sampler_kwargs(self, distributed_sampler_kwargs): diff --git a/pytorch_lightning/plugins/ddp_sequential_plugin.py b/pytorch_lightning/plugins/ddp_sequential_plugin.py index f6ed45735d5a6..a6366b4c93493 100644 --- a/pytorch_lightning/plugins/ddp_sequential_plugin.py +++ b/pytorch_lightning/plugins/ddp_sequential_plugin.py @@ -21,7 +21,6 @@ from pytorch_lightning import LightningModule from pytorch_lightning import _logger as log -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -137,7 +136,7 @@ def init_ddp_connection( self._infer_model_balance(trainer) self._assert_valid_model_balance(trainer) - def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any): + def on_before_manual_backward(self, model: DistributedDataParallel, output: Any): pass def _infer_model_balance(self, trainer): From b3f20290b4cffbb608cab467a42a9b22f18f6afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 20:11:11 +0100 Subject: [PATCH 4/9] update import --- pytorch_lightning/overrides/data_parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index a05a0e8a30c5a..905e1e4a97cec 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -20,10 +20,9 @@ import torch from torch.cuda._utils import _get_device_index from torch.nn import DataParallel -from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel._functions import Gather -from pytorch_lightning import LightningModule +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.warning_utils import WarningCache From 727360355de0fa006e7d7206a6fbd9c1fdddbfcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 20:16:47 +0100 Subject: [PATCH 5/9] module access --- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/training_loop.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 2f9e6f05d293e..2c368248feccc 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -182,7 +182,7 @@ def progress_bar_callback(self): @property def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. 
""" - ref_model = self.model if not self.data_parallel else self.model.module + ref_model = self.get_model() ref_model = cast(LightningModule, ref_model) return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b9dd2e2e195c0..0f84be3927652 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -124,9 +124,7 @@ def setup_training(self, model: LightningModule): # -------------------------- # Setup?? # -------------------------- - ref_model = model - if self.trainer.data_parallel: - ref_model = model.module + ref_model = self.trainer.get_model() # set the ranks and devices self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank From ef34dc17864ef1756c1e327755d009427844bce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 20:18:00 +0100 Subject: [PATCH 6/9] inputs --- pytorch_lightning/overrides/data_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 905e1e4a97cec..5384ecda1c126 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -159,13 +159,13 @@ def __init__(self, lightning_module: LightningModule): def forward(self, *inputs, **kwargs): if self.module.training: - output = self.module.training_step(*inputs[0], **kwargs[0]) + output = self.module.training_step(*inputs, **kwargs) fx_called = 'training_step' elif self.module.testing: - output = self.module.test_step(*inputs[0], **kwargs[0]) + output = self.module.test_step(*inputs, **kwargs) fx_called = 'test_step' else: - output = self.module.validation_step(*inputs[0], **kwargs[0]) + output = self.module.validation_step(*inputs, **kwargs) fx_called = 'validation_step' if output is None: warn_missing_output(f'{fx_called} returned None. Did you forget to return an output?') From 9b3cf2effc41ba52e2a2a0168b4c09e0404be7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Dec 2020 20:56:32 +0100 Subject: [PATCH 7/9] refactor warning --- pytorch_lightning/overrides/data_parallel.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 5384ecda1c126..425aa24ef9d67 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -21,6 +21,7 @@ from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel._functions import Gather +from typing import Any from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result @@ -160,17 +161,13 @@ def __init__(self, lightning_module: LightningModule): def forward(self, *inputs, **kwargs): if self.module.training: output = self.module.training_step(*inputs, **kwargs) - fx_called = 'training_step' + warn_if_output_is_none(output, "training_step") elif self.module.testing: output = self.module.test_step(*inputs, **kwargs) - fx_called = 'test_step' + warn_if_output_is_none(output, "test_step") else: output = self.module.validation_step(*inputs, **kwargs) - fx_called = 'validation_step' - if output is None: - warn_missing_output(f'{fx_called} returned None. 
Did you forget to return an output?') - elif self.module.use_dp or self.module.use_ddp2: - auto_squeeze_dim_zeros(output) + warn_if_output_is_none(output, "validation_step") return output # @@ -244,6 +241,11 @@ def forward(self, *inputs, **kwargs): # self._reducer_prepared_for_backwards = False +def warn_if_output_is_none(output: Any, method_name: str) -> None: + if output is None: + warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?') + + def warn_missing_output(fx_called): if fx_called == 'training_step': warning_cache.warn("Your training_step returned None. Make sure that was your intention!") From 5106ee85c979e7ca119e5a8ec484faa8b6db7ba6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 8 Jan 2021 10:06:09 +0000 Subject: [PATCH 8/9] update --- pytorch_lightning/accelerators/accelerator.py | 4 ---- pytorch_lightning/overrides/data_parallel.py | 20 +++++++++++++++++++ pytorch_lightning/plugins/ddp_plugin.py | 7 ++----- .../plugins/ddp_sequential_plugin.py | 6 +++--- pytorch_lightning/plugins/sharded_plugin.py | 5 +---- 5 files changed, 26 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9b56119a04c3e..1b3ae6f23058a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -104,10 +104,6 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): # once backward has been applied, release graph closure_loss = closure_loss.detach() - - if not automatic_optimization and self.ddp_plugin is not None: - # Manually prepare for reduce as user calling backwards manually - self.ddp_plugin.on_after_manual_backward(self.trainer.model) return closure_loss def clip_gradients(self, optimizer, clip_val=None): diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 5260c56f4d6bc..e176a2b01956a 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -19,11 +19,13 @@ from typing import Optional import torch +from torch.nn import Module from torch import Tensor from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel._functions import Gather from typing import Any +from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result @@ -168,6 +170,24 @@ def forward(self, *inputs, **kwargs): warn_if_output_is_none(output, "validation_step") return output +# In manual_optimization, we need to call reducer prepare_for_backward. +# TODO: Keep track of Pytorch DDP and update if there is a change +# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L692 +def prepare_for_backward(model: DistributedDataParallel, output: Any): + if torch.is_grad_enabled() and model.require_backward_grad_sync: + model.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. 
+ if model.find_unused_parameters: + model.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + model.reducer.prepare_for_backward([]) + else: + model.require_forward_param_sync = False + # # class LightningDistributedDataParallel(DistributedDataParallel): # """ diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 0b6e95b64a914..95469e1caa6dc 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -21,7 +21,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedWrapper +from pytorch_lightning.overrides.data_parallel import LightningDistributedWrapper, prepare_for_backward from pytorch_lightning.plugins.plugin import LightningPlugin @@ -167,10 +167,7 @@ def block_backward_sync(self, model: DistributedDataParallel): yield model.no_sync() def on_before_manual_backward(self, model: DistributedDataParallel, output: Any): - model.reducer_prepare_for_backwards(output) - - def on_after_manual_backward(self, model: DistributedDataParallel): - model.reducer_reset_hooks() + prepare_for_backward(model, output) def distributed_sampler_kwargs(self, distributed_sampler_kwargs): return distributed_sampler_kwargs diff --git a/pytorch_lightning/plugins/ddp_sequential_plugin.py b/pytorch_lightning/plugins/ddp_sequential_plugin.py index 2692a09dc58bc..82250d1ed9fdd 100644 --- a/pytorch_lightning/plugins/ddp_sequential_plugin.py +++ b/pytorch_lightning/plugins/ddp_sequential_plugin.py @@ -266,10 +266,10 @@ def _check_arguments(self, trainer): def configure_ddp( self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel: - ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids) + model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids) # Plugin handle backwards across processes. 
Currently not supported for DDP + pipe parallel - ddp_plugin.PREPARE_FOR_BACKWARDS = False - return ddp_plugin + model.require_backward_grad_sync = False + return model @rank_zero_only def rpc_save_model( diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 83395d4826a3a..95b243c566da2 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -96,7 +96,4 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list: return [] def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any): - pass - - def on_after_manual_backward(self, model: 'LightningShardedDataParallel'): - pass + pass \ No newline at end of file From c839b0d42bfabba3087b0d99a0c08675dbb89df9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 8 Jan 2021 11:20:26 +0100 Subject: [PATCH 9/9] resolve flake8 --- pytorch_lightning/overrides/data_parallel.py | 13 ++++++------- pytorch_lightning/plugins/sharded_plugin.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index e176a2b01956a..bd8a77a497968 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -16,16 +16,14 @@ import threading from collections.abc import Iterable, Mapping from itertools import chain -from typing import Optional +from typing import Any, Optional import torch -from torch.nn import Module from torch import Tensor from torch.cuda._utils import _get_device_index -from torch.nn import DataParallel -from torch.nn.parallel._functions import Gather -from typing import Any +from torch.nn import DataParallel, Module from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel._functions import Gather from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result @@ -170,7 +168,8 @@ def forward(self, *inputs, **kwargs): warn_if_output_is_none(output, "validation_step") return output -# In manual_optimization, we need to call reducer prepare_for_backward. + +# In manual_optimization, we need to call reducer prepare_for_backward. # TODO: Keep track of Pytorch DDP and update if there is a change # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L692 def prepare_for_backward(model: DistributedDataParallel, output: Any): @@ -186,7 +185,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any): else: model.reducer.prepare_for_backward([]) else: - model.require_forward_param_sync = False + model.require_forward_param_sync = False # # class LightningDistributedDataParallel(DistributedDataParallel): diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 95b243c566da2..510a44ad1bddf 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Union, Any +from typing import Any, List, Optional, Union from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer @@ -96,4 +96,4 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list: return [] def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any): - pass \ No newline at end of file + pass
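
Taken together, this series retires the custom LightningDistributedDataParallel override: the DDP plugin now builds a plain torch.nn.parallel.DistributedDataParallel around a small LightningDistributedWrapper whose forward() dispatches to training_step, test_step or validation_step, and the reducer preparation the old override performed inside forward() is re-implemented as the standalone prepare_for_backward() helper called from on_before_manual_backward(). The snippet below is a minimal, self-contained sketch of that wrapper pattern, not the pytorch_lightning implementation itself: ToyLightningModule and DistributedWrapper are illustrative stand-ins re-declared here, and the single-process "gloo" group (with an assumed free port 29500) exists only so the example runs on CPU without a launcher.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


class ToyLightningModule(torch.nn.Module):
    """Stand-in for a LightningModule exposing the hooks the wrapper dispatches to."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.testing = False  # mirrors the `module.testing` flag checked by the wrapper

    def training_step(self, batch):
        return self.layer(batch).sum()

    def validation_step(self, batch):
        return self.layer(batch).sum()

    def test_step(self, batch):
        return self.layer(batch).sum()


class DistributedWrapper(torch.nn.Module):
    """Same dispatch idea as LightningDistributedWrapper: route forward() to the right *_step."""

    def __init__(self, module: ToyLightningModule):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            return self.module.training_step(*inputs, **kwargs)
        if self.module.testing:
            return self.module.test_step(*inputs, **kwargs)
        return self.module.validation_step(*inputs, **kwargs)


if __name__ == "__main__":
    # Single-process "gloo" group so the example runs on CPU without a distributed launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")  # assumed free port
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = ToyLightningModule()
    ddp = DistributedDataParallel(DistributedWrapper(model))  # plain torch DDP, no subclass

    ddp.train()
    loss = ddp(torch.randn(8, 4))  # DDP.forward -> wrapper.forward -> training_step
    loss.backward()                # DDP's reducer handles gradient synchronisation

    ddp.eval()
    with torch.no_grad():
        ddp(torch.randn(8, 4))     # now routed to validation_step

    dist.destroy_process_group()

Note that prepare_for_backward() in data_parallel.py keys off DistributedDataParallel.require_backward_grad_sync, which is why DDPSequentialPlugin.configure_ddp() now sets require_backward_grad_sync = False on the returned model instead of the old PREPARE_FOR_BACKWARDS class flag. The helper still relies on private reducer internals (model.reducer.prepare_for_backward and _find_tensors), hence the TODO in the diff to track upstream PyTorch changes.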