2 changes: 1 addition & 1 deletion benchmarks/test_sharded_parity.py
@@ -15,7 +15,7 @@
import os
import platform
import time
from typing import Type
from typing import Type, Union

import pytest
import torch
8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/__init__.py
@@ -1,4 +1,4 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401
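The `# noqa F401` markers tell flake8 that these imports are deliberate re-exports rather than unused imports. A common alternative (not what this PR does, shown only as a sketch) is to list the re-exported names in `__all__`, which pyflakes also treats as a "use":

```python
# hypothetical alternative for pytorch_lightning/accelerators/__init__.py
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator

# declaring the public API; linters treat names listed in __all__ as used
__all__ = ["Accelerator", "CPUAccelerator", "GPUAccelerator", "TPUAccelerator"]
```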
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -145,6 +145,9 @@ def training_step(self, args):
with self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*args)

def post_training_step(self):
self.training_type_plugin.post_training_step()

def validation_step(self, args):
"""The actual validation step.

@@ -251,13 +254,13 @@ def backward(
opt_idx: the index of the optimizer
should_accumulate: whether to accumulate gradients
"""
self.training_type_plugin.pre_backward(closure_loss, optimizer, opt_idx)

output = self.precision_plugin.backward(
self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

# TODO: this is a hack, find a better solution for this (hook?)
if isinstance(self.training_type_plugin, HorovodPlugin):
optimizer.synchronize()
self.training_type_plugin.post_backward(closure_loss, optimizer, opt_idx)

return output

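With this change `Accelerator.backward` no longer special-cases Horovod: the training-type plugin gets a hook before and after the precision plugin runs the actual backward pass. A condensed sketch of the resulting control flow (based on the lines above, arguments abbreviated):

```python
def backward(self, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs):
    # let the training-type plugin prepare, e.g. DDP's reducer bookkeeping
    self.training_type_plugin.pre_backward(closure_loss, optimizer, opt_idx)

    # the precision plugin owns the real backward pass (AMP scaling, etc.)
    output = self.precision_plugin.backward(
        self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
    )

    # plugin-specific cleanup; HorovodPlugin now calls optimizer.synchronize() here
    self.training_type_plugin.post_backward(closure_loss, optimizer, opt_idx)
    return output
```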
9 changes: 4 additions & 5 deletions pytorch_lightning/core/optimizer.py
@@ -17,12 +17,9 @@

from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm


def is_lightning_optimizer(optimizer):
return isinstance(optimizer, LightningOptimizer)
@@ -62,6 +59,7 @@ def __init__(self,
self._trainer = None
self._accumulate_grad_batches = accumulate_grad_batches
self._optimizer_idx = None
self._total_optimizer_step_calls = 0

@property
def optimizer(self):
@@ -265,10 +263,11 @@ def dis_closure():

if make_optimizer_step:
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
self._total_optimizer_step_calls += 1
else:
# make sure to call optimizer_closure when accumulating
with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"):
with self._trainer.train_loop.block_ddp_sync_behaviour():
with self._trainer.train_loop.block_ddp_sync_behaviour(True):
closure()

def __repr__(self):
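Passing `True` to `block_ddp_sync_behaviour` lets the accumulation branch run its closure under DDP's `no_sync()` even in manual optimization, so the gradient all-reduce only happens on the call that really steps the optimizer. A simplified sketch of the step logic above (the real method derives `make_optimizer_step` from the accumulation schedule rather than taking it as an argument):

```python
class LightningOptimizer:
    # simplified sketch of LightningOptimizer.step, not the full implementation
    def step(self, *args, closure=None, make_optimizer_step=True, profiler_name=None, **kwargs):
        if make_optimizer_step:
            self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
            self._total_optimizer_step_calls += 1  # new counter, handy for tests
        else:
            # accumulating: still run the closure (forward + backward) but block
            # DDP gradient sync so the all-reduce only fires on the real step
            with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"):
                with self._trainer.train_loop.block_ddp_sync_behaviour(True):
                    closure()
```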
8 changes: 7 additions & 1 deletion pytorch_lightning/overrides/base.py
@@ -46,6 +46,13 @@ def forward(self, *inputs, **kwargs):

if running_stage == RunningStage.TRAINING:
output = self.module.training_step(*inputs, **kwargs)

# In manual optimization we need to prevent DDP from running its reducer here,
# since the reduction is triggered manually in ``LightningModule.manual_backward``.
# ``require_backward_grad_sync`` will be reset in the ddp plugin's
# ``post_training_step`` hook.
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")
elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
@@ -55,7 +62,6 @@ def forward(self, *inputs, **kwargs):
warn_if_output_is_none(output, "validation_step")
else:
output = self.module.predict(*inputs, **kwargs)

return output


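The `require_backward_grad_sync` toggle exists so DDP's reducer does not fire a second time when the user drives backward themselves. The user-facing pattern this supports looks roughly like the following (illustrative module, not part of this PR; the two-argument `manual_backward(loss, opt)` matches this version of the API):

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)  # assumes float batches of shape [N, 32]

    @property
    def automatic_optimization(self) -> bool:
        return False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        # routed through Accelerator.backward, which now calls the DDP plugin's
        # pre_backward (prepare_for_backward) before the actual loss.backward()
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```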
1 change: 0 additions & 1 deletion pytorch_lightning/overrides/fairscale.py
@@ -13,7 +13,6 @@
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

LightningShardedDataParallel = None
26 changes: 25 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pytorch_lightning.overrides.distributed import prepare_for_backward
import subprocess
import sys
from time import sleep
@@ -21,12 +22,14 @@
import torch
import torch.distributed as torch_distrib
from torch.nn.parallel.distributed import DistributedDataParallel

from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_THAN_1_7_0
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
@@ -177,7 +180,19 @@ def set_world_ranks(self):
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes

def pre_configure_ddp(self):
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``, breaking manual_optimization
if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
rank_zero_warn(
"From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
"to properly work with DDP."
)
self._ddp_kwargs["find_unused_parameters"] = True

def configure_ddp(self):

self.pre_configure_ddp()

self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
@@ -253,6 +268,11 @@ def barrier(self, *args, **kwargs):
def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)

def model_to_device(self):
if self.root_device.type == "cuda":
torch.cuda.set_device(self.root_device)
@@ -274,3 +294,7 @@ def test_step(self, *args, **kwargs):

def predict(self, *args, **kwargs):
return self.model(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True
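The forced `find_unused_parameters=True` mirrors what a plain PyTorch user would have to do: from 1.7.0, DDP's `Reducer._rebuild_buckets` assumes every bucket was reduced in the previous iteration, which a manually driven backward can violate. Roughly, the plugin change is equivalent to the following (sketch only; the version flag name is the one added by this PR):

```python
from distutils.version import LooseVersion

import torch
from torch.nn.parallel import DistributedDataParallel

_PYTORCH_GREATER_EQUAL_THAN_1_7_0 = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")


def wrap_for_manual_optimization(module, device_ids, **ddp_kwargs):
    # mirrors pre_configure_ddp + configure_ddp above, simplified
    if _PYTORCH_GREATER_EQUAL_THAN_1_7_0:
        # without this, bucket rebuilding in 1.7+ can fail when backward/step
        # are controlled manually and some parameters go unused
        ddp_kwargs["find_unused_parameters"] = True
    return DistributedDataParallel(module, device_ids=device_ids, **ddp_kwargs)
```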
23 changes: 23 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -13,12 +13,14 @@
# limitations under the License.
import os
import re
from pytorch_lightning.overrides.distributed import prepare_for_backward
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.distributed.dist import LightningDistributed
@@ -27,6 +29,7 @@
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_THAN_1_7_0
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
@@ -159,7 +162,18 @@ def post_training(self):
# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)

def pre_configure_ddp(self):
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``, breaking manual_optimization
if _PYTORCH_GREATER_EQUAL_THAN_1_7_0 and not self.lightning_module.automatic_optimization:
rank_zero_warn(
"From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
"to properly work with DDP."
)
self._ddp_kwargs["find_unused_parameters"] = True

def configure_ddp(self):

self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
@@ -225,6 +239,11 @@ def model_to_device(self):
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if isinstance(output, torch.Tensor):
output = sync_ddp_if_available(output, group, reduce_op)
@@ -241,3 +260,7 @@ def test_step(self, *args, **kwargs):

def predict(self, *args, **kwargs):
return self.model(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True
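`prepare_for_backward`, imported at the top of both DDP plugins, is what lets a manually triggered backward still feed DDP's reducer. Its implementation lives in `pytorch_lightning/overrides/distributed.py` and is not shown in this diff; a rough sketch, modelled on what `DistributedDataParallel.forward` itself does, is:

```python
import torch
from torch.nn.parallel.distributed import DistributedDataParallel, _find_tensors


def prepare_for_backward(model: DistributedDataParallel, output):
    # approximate behaviour only; see the actual module for the real version
    if torch.is_grad_enabled() and model.require_backward_grad_sync:
        model.require_forward_param_sync = True
        if model.find_unused_parameters:
            # tell the reducer which tensors the backward graph will reach, so
            # unused parameters can be marked ready for reduction
            model.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            model.reducer.prepare_for_backward([])
```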
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -15,7 +15,7 @@
from typing import Any, List, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import _LRScheduler, Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -116,6 +116,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = hvd.broadcast_object(obj, src)
return obj

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
optimizer.synchronize()

def model_to_device(self):
if self.on_gpu:
torch.cuda.set_device(self.root_device)
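`optimizer.synchronize()` finishes the asynchronous gradient allreduce that Horovod's `DistributedOptimizer` launches during backward; moving it into `post_backward` keeps the accelerator free of Horovod-specific branches. The equivalent pattern in plain Horovod looks roughly like this (standard Horovod usage, not Lightning code):

```python
import horovod.torch as hvd
import torch


def train_step(model: torch.nn.Module, optimizer, batch: torch.Tensor) -> None:
    # `optimizer` is assumed to be an hvd.DistributedOptimizer
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()              # gradient allreduce starts asynchronously
    optimizer.synchronize()      # what the new post_backward hook does
    # gradient clipping, if any, belongs between synchronize() and step()
    with optimizer.skip_synchronize():
        optimizer.step()         # avoid synchronizing a second time inside step()
```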
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -100,8 +100,9 @@ def block_backward_sync(self):
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, (LightningDistributedDataParallel, DistributedDataParallel)):
yield self.model.no_sync()
if isinstance(self.model, DistributedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

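The old code yielded the object returned by `no_sync()` without ever entering it, so gradient synchronization was never actually blocked; wrapping it in `with` is the fix. A self-contained illustration of the difference:

```python
from contextlib import contextmanager


class Toy:
    def no_sync(self):
        # returns a context manager, just like DistributedDataParallel.no_sync()
        @contextmanager
        def cm():
            print("sync blocked")
            yield
            print("sync restored")
        return cm()


@contextmanager
def broken(model):
    yield model.no_sync()   # context manager created but never entered: no effect


@contextmanager
def fixed(model):
    with model.no_sync():   # actually enters no_sync() for the caller's block
        yield


with broken(Toy()):
    pass                    # prints nothing
with fixed(Toy()):
    pass                    # prints "sync blocked" then "sync restored"
```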
11 changes: 10 additions & 1 deletion pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -16,7 +16,7 @@
from typing import Any, Optional, Sequence, TYPE_CHECKING, Union

import torch

from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
@@ -69,6 +69,12 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
"""Reduce the early stopping decision across all possibly spawned processes"""
return should_stop

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
"""Run after precision plugin executes backward"""

@property
def model(self) -> torch.nn.Module:
"""Returns the potentially wrapped LightningModule"""
@@ -107,6 +113,9 @@ def start_testing(self, trainer: 'Trainer') -> None:
def training_step(self, *args, **kwargs):
return self.lightning_module.training_step(*args, **kwargs)

def post_training_step(self):
pass

def validation_step(self, *args, **kwargs):
return self.lightning_module.validation_step(*args, **kwargs)

3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/properties.py
@@ -17,10 +17,9 @@
from argparse import ArgumentParser, Namespace
from typing import Any, cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer to automate the training."""
import os
import warnings
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
16 changes: 11 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -18,7 +18,6 @@
import numpy as np
import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -282,6 +281,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
model_ref._results = Result()
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.accelerator_backend.post_training_step()

self.trainer.logger_connector.cache_logged_metrics()

self._check_training_step_output(training_step_output)
@@ -689,7 +690,7 @@ def train_step_and_backward_closure():
return result

@contextmanager
def block_ddp_sync_behaviour(self):
def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
"""
automatic_optimization = True
Blocks ddp sync gradients behaviour on backwards pass.
@@ -703,8 +704,12 @@ def block_ddp_sync_behaviour(self):
context manager with sync behaviour off

"""
if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and self.automatic_optimization:
yield self.trainer.training_type_plugin.block_backward_sync()
if (
isinstance(self.trainer.training_type_plugin, ParallelPlugin)
and (self.automatic_optimization or should_block_sync)
):
with self.trainer.training_type_plugin.block_backward_sync():
yield None
else:
yield None

@@ -745,7 +750,8 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
self._curr_step_result = result

if result is None:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
if self.automatic_optimization:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
return None

if self.trainer.train_loop.automatic_optimization:
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -35,6 +35,7 @@
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PYTORCH_GREATER_EQUAL_THAN_1_7_0,
_PYTORCH_PRUNE_AVAILABLE,
_RPC_AVAILABLE,
_TORCHTEXT_AVAILABLE,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -55,4 +55,5 @@ def _module_available(module_path: str) -> bool:
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_PYTORCH_PRUNE_AVAILABLE = _module_available('torch.nn.utils.prune')
_PYTORCH_GREATER_EQUAL_THAN_1_7_0 = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
1 change: 0 additions & 1 deletion tests/accelerators/legacy/test_accelerator_connector.py
@@ -25,7 +25,6 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.utilities import DistributedType
from tests.base.boring_model import BoringModel

