4 changes: 0 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -104,10 +104,6 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# once backward has been applied, release graph
closure_loss = closure_loss.detach()

if not automatic_optimization and self.ddp_plugin is not None:
# Manually prepare for reduce as user calling backwards manually
self.ddp_plugin.on_after_manual_backward(self.trainer.model)
return closure_loss

def clip_gradients(self, optimizer, clip_val=None):
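Note on the removed hook: with manual optimization, the DDP reducer is now primed once before the user-driven backward (via the plugin's `on_before_manual_backward`, which calls the `prepare_for_backward` helper added in `data_parallel.py` below), so no post-backward reset is needed. A minimal sketch of that sequence with plain torch DDP, assuming a single-process gloo group and a free port 29500; the Trainer plumbing is omitted and comments mark where the hook would run:

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Single-process gloo group so the sketch runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)

model = DistributedDataParallel(nn.Linear(4, 2))
output = model(torch.randn(8, 4))

# Here Lightning's manual optimization would call
# ddp_plugin.on_before_manual_backward(model, output) to prime the reducer.
closure_loss = output.sum()
closure_loss.backward()

# Once backward has been applied, release the graph; no post-backward hook.
closure_loss = closure_loss.detach()

dist.destroy_process_group()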
170 changes: 107 additions & 63 deletions pytorch_lightning/overrides/data_parallel.py
@@ -16,7 +16,7 @@
import threading
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Optional
from typing import Any, Optional

import torch
from torch import Tensor
@@ -25,6 +25,7 @@
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warnings import WarningCache

@@ -149,74 +150,117 @@ def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])


class LightningDistributedDataParallel(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """
    PREPARE_FOR_BACKWARDS = True

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def forward(self, *inputs, **kwargs):  # pragma: no-cover
        self._sync_params()
        self.reducer_reset_hooks()
        fx_called: str = ''

        if self.device_ids:

            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # --------------
                # LIGHTNING MOD
                # --------------
                # normal
                # output = self.module(*inputs[0], **kwargs[0])
                # lightning
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                    fx_called = 'training_step'
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                    fx_called = 'test_step'
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])
                    fx_called = 'validation_step'
            else:
                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            # output = self.module(*inputs, **kwargs)
            # normal lightning (ddp_cpu)
            if self.module.training:
                output = self.module.training_step(*inputs, **kwargs)
            elif self.module.testing:
                output = self.module.test_step(*inputs, **kwargs)
            else:
                output = self.module.validation_step(*inputs, **kwargs)

        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
            self.reducer_prepare_for_backwards(output)

        if output is None:
            warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
        return output


class LightningDistributedWrapper(torch.nn.Module):

    def __init__(self, lightning_module: LightningModule):
        super().__init__()
        self.module = lightning_module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            output = self.module.training_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "training_step")
        elif self.module.testing:
            output = self.module.test_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "test_step")
        else:
            output = self.module.validation_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "validation_step")
        return output


# In manual_optimization, we need to call reducer prepare_for_backward.
# TODO: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L692
def prepare_for_backward(model: DistributedDataParallel, output: Any):
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if model.find_unused_parameters:
model.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False

#
# class LightningDistributedDataParallel(DistributedDataParallel):
# """
# Override the forward call in lightning so it goes to training and validation step respectively
# """
# PREPARE_FOR_BACKWARDS = True
#
# def parallel_apply(self, replicas, inputs, kwargs):
# return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
#
# def forward(self, *inputs, **kwargs): # pragma: no-cover
# self._sync_params()
# self.reducer_reset_hooks()
# fx_called: str = ''
#
# if self.device_ids:
#
# inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# if len(self.device_ids) == 1:
# # --------------
# # LIGHTNING MOD
# # --------------
# # normal
# # output = self.module(*inputs[0], **kwargs[0])
# # lightning
# if self.module.training:
# output = self.module.training_step(*inputs[0], **kwargs[0])
# fx_called = 'training_step'
# elif self.module.testing:
# output = self.module.test_step(*inputs[0], **kwargs[0])
# fx_called = 'test_step'
# else:
# output = self.module.validation_step(*inputs[0], **kwargs[0])
# fx_called = 'validation_step'
# else:
# outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
# output = self.gather(outputs, self.output_device)
# else:
# # output = self.module(*inputs, **kwargs)
# # normal lightning (ddp_cpu)
# if self.module.training:
# output = self.module.training_step(*inputs, **kwargs)
# elif self.module.testing:
# output = self.module.test_step(*inputs, **kwargs)
# else:
# output = self.module.validation_step(*inputs, **kwargs)
#
# if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
# self.reducer_prepare_for_backwards(output)
#
# if output is None:
# warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
# return output
#
# def reducer_prepare_for_backwards(self, output):
# self._reducer_prepared_for_backwards = True
# if torch.is_grad_enabled():
# # We'll return the output object verbatim since it is a freeform
# # object. We need to find any tensors in this object, though,
# # because we need to figure out which parameters were used during
# # this forward pass, to ensure we short circuit reduction for any
# # unused parameters. Only if `find_unused_parameters` is set.
# if self.find_unused_parameters:
# self.reducer.prepare_for_backward(list(_find_tensors(output)))
# else:
# self.reducer.prepare_for_backward([])
#
# def reducer_reset_hooks(self):
# self._reducer_prepared_for_backwards = False


def reducer_prepare_for_backwards(self, output):
self._reducer_prepared_for_backwards = True
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])

def reducer_reset_hooks(self):
self._reducer_prepared_for_backwards = False
def warn_if_output_is_none(output: Any, method_name: str) -> None:
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def warn_missing_output(fx_called):
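For orientation, a self-contained sketch of the pattern this file moves to: instead of subclassing `DistributedDataParallel`, a thin `torch.nn.Module` wrapper routes `forward()` to `training_step` / `test_step` / `validation_step`, and vanilla DDP can wrap it unchanged. The class and module names below are illustrative stand-ins, not this file's code:

import torch
from torch import nn


class ToyLightningLike(nn.Module):
    """Stand-in for a LightningModule: step methods plus a `testing` flag."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.testing = False

    def training_step(self, batch):
        return self.layer(batch).sum()

    def validation_step(self, batch):
        return self.layer(batch).mean()

    def test_step(self, batch):
        return self.layer(batch).abs().sum()


class StepDispatchWrapper(nn.Module):
    """Routes forward() to the step method matching the wrapped module's mode."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            return self.module.training_step(*inputs, **kwargs)
        if self.module.testing:
            return self.module.test_step(*inputs, **kwargs)
        return self.module.validation_step(*inputs, **kwargs)


wrapped = StepDispatchWrapper(ToyLightningLike())
wrapped.train()
print(wrapped(torch.randn(3, 4)))  # dispatches to training_step
wrapped.eval()
print(wrapped(torch.randn(3, 4)))  # dispatches to validation_step
# In the real change, DistributedDataParallel(module=LightningDistributedWrapper(model))
# wraps such a dispatcher instead of a custom DDP subclass.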
26 changes: 13 additions & 13 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -16,11 +16,12 @@
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.overrides.data_parallel import LightningDistributedWrapper, prepare_for_backward
from pytorch_lightning.plugins.plugin import LightningPlugin


@@ -48,7 +49,7 @@ def __init__(self, **kwargs):

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> LightningDistributedDataParallel:
) -> DistributedDataParallel:
"""
Pass through all customizations from constructor to `LightningDistributedDataParallel`.
Override to define a custom DDP implementation.
@@ -76,8 +77,8 @@ def configure_ddp(self, model, device_ids):
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
)
model = LightningDistributedDataParallel(
model,
model = DistributedDataParallel(
module=LightningDistributedWrapper(model),
device_ids=device_ids,
**self._ddp_kwargs,
)
@@ -134,7 +135,7 @@ def on_after_setup_optimizers(self, trainer):

def get_model_from_plugin(
self,
model: Union[LightningDistributedDataParallel, LightningModule]
model: Union[DistributedDataParallel, LightningModule]
) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
@@ -150,24 +151,23 @@ def get_model_from_plugin(
Returns: Reference :class:`LightningModule` within parallel wrapper.

"""
if isinstance(model, LightningDistributedDataParallel):
return model.module
if isinstance(model, DistributedDataParallel):
model = model.module
if isinstance(model, LightningDistributedWrapper):
model = model.module
return model

@contextmanager
def block_backward_sync(self, model: LightningDistributedDataParallel):
def block_backward_sync(self, model: DistributedDataParallel):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
yield model.no_sync()

def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
model.reducer_prepare_for_backwards(output)

def on_after_manual_backward(self, model: LightningDistributedDataParallel):
model.reducer_reset_hooks()
def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
prepare_for_backward(model, output)

def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
return distributed_sampler_kwargs
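With vanilla DDP now wrapping `LightningDistributedWrapper`, the original module sits two `.module` hops deep, which is what the updated `get_model_from_plugin` unwraps. A minimal sketch of that double unwrap, using stand-in wrapper classes so it runs without a process group (the real code checks `DistributedDataParallel` and `LightningDistributedWrapper`):

from torch import nn


class OuterWrapper(nn.Module):
    """Stand-in for DistributedDataParallel (exposes the wrapped net as .module)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


class InnerWrapper(nn.Module):
    """Stand-in for LightningDistributedWrapper."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


def unwrap(model: nn.Module) -> nn.Module:
    # Peel off the DDP layer first, then the Lightning dispatch wrapper.
    if isinstance(model, OuterWrapper):
        model = model.module
    if isinstance(model, InnerWrapper):
        model = model.module
    return model


core = nn.Linear(2, 2)
assert unwrap(OuterWrapper(InnerWrapper(core))) is core
assert unwrap(core) is core  # already-unwrapped modules pass through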
9 changes: 4 additions & 5 deletions pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -21,7 +21,6 @@

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -137,7 +136,7 @@ def init_ddp_connection(
self._infer_model_balance(trainer)
self._assert_valid_model_balance(trainer)

def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
pass

def _infer_model_balance(self, trainer):
@@ -267,10 +266,10 @@ def _check_arguments(self, trainer):
def configure_ddp(
self,
model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
ddp_plugin.PREPARE_FOR_BACKWARDS = False
return ddp_plugin
model.require_backward_grad_sync = False
return model

@rank_zero_only
def rpc_save_model(
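The removed `PREPARE_FOR_BACKWARDS` flag is replaced here by DDP's own `require_backward_grad_sync` attribute, the same switch `no_sync()` toggles: with it set to `False`, the forward pass skips reducer preparation and backward leaves gradients un-synced. A minimal sketch, assuming a single-process gloo group and a free port 29500:

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)

ddp = DistributedDataParallel(nn.Linear(4, 2))
# Same flag that DistributedDataParallel.no_sync() flips: skip reducer prep,
# so this backward keeps gradients local instead of all-reducing them.
ddp.require_backward_grad_sync = False
ddp(torch.randn(3, 4)).sum().backward()

dist.destroy_process_group()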
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/sharded_plugin.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -97,6 +97,3 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:

def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
pass

def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
pass
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/properties.py
@@ -185,7 +185,7 @@ def progress_bar_callback(self):
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
ref_model = self.get_model()
ref_model = cast(LightningModule, ref_model)
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)

4 changes: 1 addition & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -137,9 +137,7 @@ def setup_training(self, model: LightningModule):
# --------------------------
# Setup??
# --------------------------
ref_model = model
if self.trainer.data_parallel:
ref_model = model.module
ref_model = self.trainer.get_model()

# set the ranks and devices
self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank