3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -374,3 +374,6 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:

def on_save(self, checkpoint):
return checkpoint

def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)
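The new `barrier` hook on the base accelerator simply forwards to the training type plugin. A minimal sketch of what a DDP-style plugin's side of that contract could look like (illustrative class, not the real DDPPlugin; assumes torch.distributed is the backend in use):

from typing import Optional

import torch.distributed as dist


class _SketchDDPTrainingType:
    """Illustrative stand-in for a training type plugin."""

    def barrier(self, name: Optional[str] = None) -> None:
        # Block until every process in the default group reaches this point;
        # the optional name is unused in this sketch.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()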
82 changes: 77 additions & 5 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
from typing import Optional, Sequence

import torch

@@ -26,15 +27,21 @@
DataParallelPlugin,
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
DDPSpawnShardedPlugin,
HorovodPlugin,
NativeMixedPrecisionPlugin,
PrecisionPlugin,
RPCPlugin,
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
SingleTPUPlugin,
TPUHalfPrecisionPlugin,
TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
DDPShardedPlugin,
DDPSpawnShardedPlugin,
)
from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -74,6 +81,7 @@ def __init__(
amp_type,
amp_level,
cluster_environment,
plugins,
):
# initialization
self._device_type = DeviceType.CPU
@@ -95,6 +103,11 @@ def __init__(
self.cluster_environment = cluster_environment
self.is_slurm_managing_tasks = False

self._precision_plugin: Optional[PrecisionPlugin] = None
self._training_type_plugin: Optional[TrainingTypePlugin] = None

self.handle_given_plugins(plugins)

# init the default rank if exists
# we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
# this way we only show it on rank 0
@@ -136,6 +149,56 @@ def __init__(

self.replace_sampler_ddp = replace_sampler_ddp

def handle_given_plugins(self, plugins: Optional[Sequence]):
if plugins is None:
return

if not isinstance(plugins, Sequence):
plugins = [plugins]

training_type = None
precision = None

for plug in plugins:
if isinstance(plug, TrainingTypePlugin):
if training_type is None:
training_type = plug
else:
raise MisconfigurationException(
'You can only specify one precision and one training type plugin. '
'Found more than 1 training type plugin'
)
elif isinstance(plug, PrecisionPlugin):
if precision is None:
precision = plug
else:
raise MisconfigurationException(
'You can only specify one precision and one training type plugin. '
'Found more than 1 precision plugin'
)
else:
raise MisconfigurationException(
f'Found invalid type for plugin {plug}. '
'Expected a precision or training type plugin.'
)

self._training_type_plugin = training_type
self._precision_plugin = precision

@property
def precision_plugin(self) -> PrecisionPlugin:
if self._precision_plugin is None:
self._precision_plugin = self.select_precision_plugin()

return self._precision_plugin

@property
def training_type_plugin(self) -> TrainingTypePlugin:
if self._training_type_plugin is None:
self._training_type_plugin = self.select_training_type_plugin()

return self._training_type_plugin

@property
def on_cpu(self):
return self._device_type == DeviceType.CPU
@@ -205,6 +268,9 @@ def select_precision_plugin(self):
if self.on_tpu:
return TPUHalfPrecisionPlugin()

if isinstance(self.training_type_plugin, RPCPlugin):
raise MisconfigurationException

if self.amp_type == "native":
if not _NATIVE_AMP_AVAILABLE:
rank_zero_warn(
@@ -215,7 +281,7 @@ def select_precision_plugin(self):
self.amp_type = "apex"
else:
log.info("Using native 16bit precision.")
if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
return ShardedNativeMixedPrecisionPlugin()
self.amp_type = AMPType.NATIVE
return NativeMixedPrecisionPlugin()
@@ -227,7 +293,7 @@ def select_precision_plugin(self):
" Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
)
else:
if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
raise MisconfigurationException(
"Sharded Plugin is not supported with Apex AMP, "
"please using native AMP for 16-bit precision."
@@ -289,6 +355,12 @@ def select_training_type_plugin(self):
def select_accelerator(self):
if isinstance(self.distributed_backend, Accelerator):
# custom accelerator from user
if self._precision_plugin is not None or self._training_type_plugin is not None:
# plugins also specified by user
rank_zero_warn(
'Specified Precision and TrainingType Plugins will be ignored, '
'since an Accelerator instance was provided'
)
return self.distributed_backend

if self.on_gpu:
@@ -299,8 +371,8 @@ def select_accelerator(self):
acc_cls = CPUAccelerator

return acc_cls(
precision_plugin=self.select_precision_plugin(),
training_type_plugin=self.select_training_type_plugin(),
precision_plugin=self.precision_plugin,
training_type_plugin=self.training_type_plugin,
)

def select_cluster_environment(self):
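For orientation, the pattern introduced in `handle_given_plugins` plus the two lazy properties boils down to: plugins supplied by the user win, and automatic selection only runs when nothing was supplied. A self-contained sketch with stand-in classes (not the real connector or plugin types):

from typing import Optional


class _Precision:          # stand-in for PrecisionPlugin
    pass


class _TrainingType:       # stand-in for TrainingTypePlugin
    pass


class _Connector:
    def __init__(self, plugins: Optional[list] = None):
        self._precision: Optional[_Precision] = None
        self._training_type: Optional[_TrainingType] = None
        for plug in plugins or []:
            if isinstance(plug, _TrainingType):
                self._training_type = plug
            elif isinstance(plug, _Precision):
                self._precision = plug

    @property
    def precision_plugin(self) -> _Precision:
        if self._precision is None:
            # nothing user-supplied -> fall back to automatic selection, exactly once
            self._precision = _Precision()
        return self._precision


custom = _Precision()
assert _Connector(plugins=[custom]).precision_plugin is custom  # user plugin wins
assert isinstance(_Connector().precision_plugin, _Precision)    # lazy default otherwise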
21 changes: 21 additions & 0 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,9 +1,17 @@
from typing import Callable

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm


class TPUAccelerator(Accelerator):

@@ -17,3 +25,16 @@ def setup(self, trainer, model):
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def optimizer_step(
self, optimizer: torch.optim.Optimizer, current_epoch: int, batch_idx: int, opt_idx: int,
lambda_closure: Callable
):

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure})

self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
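For reference, the `xm.optimizer_step` call used above reduces the gradients across TPU cores and then invokes `optimizer.step(closure=...)`. A minimal standalone sketch of the same pattern (runnable only where torch_xla is installed; the model and data below are placeholders):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = (torch.randn(4, 10, device=device), torch.randn(4, 1, device=device))


def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch[0]), batch[1])
    loss.backward()
    return loss


# all-reduce gradients across cores, then call optimizer.step(closure=closure)
xm.optimizer_step(optimizer, optimizer_args={'closure': closure})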
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -196,6 +196,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
should_stop = False
else:
self.wait_count += 1
should_stop = self.wait_count >= self.patience
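The added `should_stop = False` makes the improvement branch reset the stop decision explicitly instead of relying on the previous value. The patience logic, reduced to a standalone sketch (min-mode only, simplified from the callback):

def stops_after(scores, patience=3, min_delta=0.0):
    """Return True if scores (lower is better) fail to improve for `patience` checks."""
    best, wait = float('inf'), 0
    for current in scores:
        if current - min_delta < best:   # improvement: record it and reset the counter
            best, wait = current, 0
        else:
            wait += 1
            if wait >= patience:
                return True
    return False


assert stops_after([1.0, 0.9, 0.9, 0.9, 0.9]) is True    # stagnates for 3 checks
assert stops_after([1.0, 0.9, 0.8, 0.7, 0.6]) is False   # keeps improving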
1 change: 0 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -275,7 +275,6 @@ def log(
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

accelerator = self.trainer.accelerator_backend
training_type_plugin = self.trainer.training_type_plugin

self._results.log(
19 changes: 4 additions & 15 deletions pytorch_lightning/core/optimizer.py
@@ -20,9 +20,6 @@
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm


def is_lightning_optimizer(optimizer):
return isinstance(optimizer, LightningOptimizer)
@@ -133,18 +130,10 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
optimizer = self._optimizer
model = trainer.get_model()

if trainer._device_type == DeviceType.TPU:
with trainer.profiler.profile(profiler_name):
xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

# elif trainer.amp_backend is not None:
# # TODO: Adapt for new optimizer structure
# trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)

else:
with trainer.profiler.profile(profiler_name):
optimizer.step(closure=closure, *args, **kwargs)

with trainer.profiler.profile(profiler_name):
trainer.accelerator_backend.optimizer_step(*args, lambda_closure=closure, **kwargs)

# TODO: Do we need this?
accelerator_backend = trainer.accelerator_backend
if accelerator_backend is not None and accelerator_backend.rpc_enabled:
if accelerator_backend.ddp_plugin.is_main_rpc_process:
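`LightningOptimizer` no longer special-cases TPUs; it always hands the step to the accelerator, which wraps it with the precision and training-type hooks (and, on TPU, the `xm.optimizer_step` path shown earlier). A reduced sketch of that call chain with an illustrative accelerator class:

from typing import Callable

import torch


class _SketchAccelerator:
    """Illustrative only: the accelerator owns the hooks around optimizer.step()."""

    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable) -> None:
        # precision_plugin.pre_optimizer_step / training_type_plugin.pre_optimizer_step would run here
        optimizer.step(closure=lambda_closure)
        # ...and the matching post_optimizer_step hooks here


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(4, 2)).sum()
    loss.backward()
    return loss


# roughly what LightningOptimizer.__optimizer_step now delegates to:
_SketchAccelerator().optimizer_step(optimizer, lambda_closure=closure)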
21 changes: 8 additions & 13 deletions pytorch_lightning/plugins/__init__.py
@@ -11,6 +11,10 @@
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
@@ -19,17 +23,8 @@
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

__all__ = [
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUHalfPrecisionPlugin",
"TPUSpawnPlugin",
"ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", "HorovodPlugin",
"NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", "SingleDevicePlugin",
"SingleTPUPlugin", "TPUHalfPrecisionPlugin", "TPUSpawnPlugin", 'RPCPlugin', 'RPCSequentialPlugin'
'TrainingTypePlugin', 'ParallelPlugin', 'Plugin', 'DDPShardedPlugin', 'DDPSpawnShardedPlugin'
]
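With the re-exports above, the new plugins can be imported straight from the package root, for example (assumes a checkout containing this change):

from pytorch_lightning.plugins import (
    DDPShardedPlugin,
    DDPSpawnShardedPlugin,
    RPCPlugin,
    RPCSequentialPlugin,
    TrainingTypePlugin,
)

assert issubclass(DDPShardedPlugin, TrainingTypePlugin)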
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -4,6 +4,8 @@
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
13 changes: 12 additions & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Optional
@@ -22,7 +23,7 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


class ParallelPlugin(TrainingTypePlugin, ABC):
@@ -102,3 +103,13 @@ def block_backward_sync(self):
yield self.model.no_sync()
else:
yield None

def broadcast(self, obj: object, src: int) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float)
data = all_gather_ddp_if_available(data_tensor)
buffer = io.BytesIO(data.cpu().byte().numpy())
obj = torch.load(buffer)
return obj
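The broadcast above works by serializing an arbitrary object with `torch.save`, shipping the bytes as a tensor, and loading them back on the receiving side. The (de)serialization round trip in isolation, which needs no process group (the all-gather step is elided):

import io

import torch


def obj_to_tensor(obj: object) -> torch.Tensor:
    buffer = io.BytesIO()
    torch.save(obj, buffer)                      # arbitrary picklable object -> bytes
    return torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.uint8)


def tensor_to_obj(data: torch.Tensor) -> object:
    buffer = io.BytesIO(data.cpu().byte().numpy().tobytes())
    return torch.load(buffer)                    # bytes -> original object


payload = {"epoch": 3, "weights": torch.arange(4.0)}
restored = tensor_to_obj(obj_to_tensor(payload))
assert restored["epoch"] == 3
assert torch.equal(restored["weights"], payload["weights"])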
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -93,6 +93,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
return dataloader

is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu

need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -141,6 +141,7 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
is returned from `configure_optimizers`.
9 changes: 7 additions & 2 deletions pytorch_lightning/trainer/properties.py
@@ -15,11 +15,12 @@
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import cast, List, Optional, Type, TypeVar, Union, Any
from typing import Any, cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -46,6 +47,10 @@
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.model_utils import is_overridden


class TrainerProperties(ABC):
