54 changes: 3 additions & 51 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -25,6 +25,7 @@
DEFAULT_RPC_TIMEOUT_SEC = 60.
if _RPC_AVAILABLE:
from torch.distributed import rpc

with suppress(ModuleNotFoundError, ImportError):
from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC

@@ -76,60 +77,11 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
"""
raise NotImplementedError

def on_main_rpc_connection(self, trainer) -> None:
"""
Called when main rpc connection has been established.

Args:
trainer: The trainer object.
"""
raise NotImplementedError

def on_accelerator_exit_rpc_process(self) -> None:
"""
Called to exit RPC process within the accelerator, that is being managed by main process.

Args:
trainer: The trainer object.
"""
self.exit_rpc_process()

def exit_rpc_process(self):
if self._is_rpc_initialized:
torch.distributed.rpc.shutdown()
self._is_rpc_initialized = False

@property
def return_after_exit_rpc_process(self) -> bool:
"""
Override to decide whether to skip train/test function after shutdown completed.
Usually RPC shutdown is a join/exit function, afterwards we want to exit the process.

Returns:
Whether to return after RPC exit.
"""
raise NotImplementedError

def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
"""
Called when optimizer step is run on the main process. Used to signal any RPC workers to run optimizer step.

Args:
model: The LightningModule.
opt_idx: The idx of the optimizer to carry out step on.
"""
raise NotImplementedError

@property
def is_main_rpc_process(self) -> bool:
"""
Override to add logic to determine current process is main RPC process.
"""
raise NotImplementedError

def barrier(self, name: Optional[str] = None) -> None:
"""
Override to define distributed sync communication. This needs to be handled differently due to
the RPC connection managing certain processes at the same time.
"""
raise NotImplementedError
def rpc_enabled(self) -> bool:
return True
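The diff above strips most of the abstract hooks out of `RPCPlugin`, leaving the model-saving hook, the RPC teardown, and the `rpc_enabled` flag. Below is a minimal, self-contained sketch of that remaining surface, assuming the pieces visible in the hunks; `RPCPluginSketch` and its stub `__init__` are illustrative names added here, not the repository class, whose base class and constructor fall outside the visible hunks.

```python
import torch
from torch.distributed import rpc  # requires a torch build with distributed/RPC support


class RPCPluginSketch:
    def __init__(self) -> None:
        # In the real plugin this flag is set once torch.distributed.rpc is initialized.
        self._is_rpc_initialized = False

    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
        """Override to save the model via RPC; the base implementation raises."""
        raise NotImplementedError

    def exit_rpc_process(self):
        # Shut the RPC framework down once, then record that it has been torn down.
        if self._is_rpc_initialized:
            rpc.shutdown()
            self._is_rpc_initialized = False

    @property
    def rpc_enabled(self) -> bool:
        return True
```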
96 changes: 53 additions & 43 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,7 +13,7 @@
# limitations under the License
import logging
import os
from typing import Any, List, Optional, Sequence
from typing import List, Optional

import torch
import torch.distributed as torch_distrib
@@ -22,8 +22,7 @@
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
@@ -97,15 +96,18 @@ def __init__(
self.checkpoint = checkpoint
self.balance_mode = balance_mode
self.pipelined_backward = pipelined_backward
self.main_rpc_process = False # Updated by main process, default for all secondary processes
self._main_rpc_process = True

def init_ddp_connection(
self,
global_rank: int,
world_size: int,
) -> None:
# what is this used for?
self.prepared_for_backwards = False
if self.lightning_module.trainer.amp_backend is not None:
raise MisconfigurationException(
'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision'
)

if self._skip_init_connections():
return
super().init_ddp_connection(
@@ -119,21 +121,18 @@ def init_ddp_connection(
self.set_main_rpc_process()

self._check_sequential_model_exists(model)

# check if user given balance is valid
if self.balance is not None:
self._assert_valid_model_balance()

if self.main_rpc_process:
if self.balance is None:
self._infer_model_balance()
self._assert_valid_model_balance()

if not self.is_main_rpc_process:
self.on_accelerator_exit_rpc_process()
self.exit_rpc_process()
if self.return_after_exit_rpc_process:
return
self.init_pipe_module()
else:
self.on_main_rpc_connection()

def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
pass
self.handle_transferred_pipe_module()
self.exit_rpc_process()

def _infer_model_balance(self):
log.info(f'Inferring model balance using {self.balance_mode} mode')
@@ -231,43 +230,40 @@ def _infer_check_num_gpus(self):
# Assume that the user wants to balance his model on all GPUs
return self.world_size

def on_accelerator_exit_rpc_process(self) -> None:
def handle_transferred_pipe_module(self) -> None:
if not self.lightning_module.running_stage == RunningStage.TESTING:
torch_distrib.barrier() # Ensure we await main process initialization

# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
del rpc_pipe.PipeModel.trainer.model.sequential_module
rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel
rpc_pipe.PipeModel.configure_optimizers = self.lightning_module.configure_optimizers
super().on_accelerator_exit_rpc_process()

def set_main_rpc_process(self):
self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0

def on_main_rpc_connection(self) -> None:
def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
if not self.lightning_module.running_stage == RunningStage.TESTING:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

# TODO: Move this to the connector
def _check_arguments(self, trainer):
if trainer.amp_backend is not None:
raise MisconfigurationException(
'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
)
# TODO: Move this to the connector

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""

def configure_ddp(self) -> None:
# process_group=mpu.get_data_parallel_group()
super().configure_ddp()
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
self._model.require_backward_grad_sync = False
def configure_ddp(self):
if self.main_rpc_process:
self.pre_configure_ddp()

self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
process_group=mpu.get_data_parallel_group(),
**self._ddp_kwargs,
)
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
self._model.require_backward_grad_sync = False

@rank_zero_only
def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
Expand Down Expand Up @@ -302,16 +298,19 @@ def distributed_sampler_kwargs(self):
def data_parallel_group(self):
return mpu.get_data_parallel_group()

@property
def is_main_rpc_process(self) -> bool:
return self.main_rpc_process
def set_main_rpc_process(self):
self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0

@property
def return_after_exit_rpc_process(self) -> bool:
return True
def main_rpc_process(self) -> bool:
return self._main_rpc_process

@main_rpc_process.setter
def main_rpc_process(self, is_main_process):
self._main_rpc_process = is_main_process

def barrier(self, name: Optional[str] = None) -> None:
if torch_distrib.is_initialized() and self.is_main_rpc_process:
if torch_distrib.is_initialized() and self.main_rpc_process:
torch_distrib.barrier(group=self.data_parallel_group)

def _check_pipe_available(self):
@@ -322,11 +321,22 @@ def _check_pipe_available(self):

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
"""Hook to do something after each optimizer step."""
if self.rpc_enabled and self.is_main_rpc_process:

if self.rpc_enabled and self.main_rpc_process:
# Initialize optimizer step on main process
self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs)

def post_training(self):
if self.main_rpc_process:
super().post_training()

def start_training(self, trainer: 'Trainer') -> None:
if self.main_rpc_process:
super().start_training(trainer)

def start_testing(self, trainer: 'Trainer') -> None:
if self.main_rpc_process:
super().start_testing(trainer)


class LightningPipeModule(nn.Module):
"""
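One pattern the `rpc_sequential.py` hunks introduce is worth isolating: `main_rpc_process` becomes a property backed by `_main_rpc_process` (defaulting to `True`), and the training entry points only run on the main RPC process. The toy class below illustrates that gating on its own, assuming nothing beyond the visible hunks; `MainRpcGateSketch` and the string stub passed as `trainer` are placeholders, not Lightning objects.

```python
class MainRpcGateSketch:
    """Toy stand-in for the main-process gating added to RPCSequentialPlugin."""

    def __init__(self) -> None:
        # The diff defaults this to True and lets set_main_rpc_process() overwrite it
        # on secondary processes based on the pipeline-parallel rank.
        self._main_rpc_process = True

    @property
    def main_rpc_process(self) -> bool:
        return self._main_rpc_process

    @main_rpc_process.setter
    def main_rpc_process(self, is_main_process: bool) -> None:
        self._main_rpc_process = is_main_process

    def start_training(self, trainer: object) -> None:
        # Only the main RPC process enters the training loop; non-main workers skip it
        # and keep serving the pipe over RPC instead.
        if self.main_rpc_process:
            print(f"main process: would run training with {trainer!r}")


plugin = MainRpcGateSketch()
plugin.main_rpc_process = False          # as a worker process would be marked
plugin.start_training(trainer="stub")    # prints nothing: gated off on workers
```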
1 change: 0 additions & 1 deletion tests/plugins/legacy/__init__.py

This file was deleted.

@@ -26,20 +26,13 @@
from tests.helpers.boring_model import RandomDataset


def cleanup(ctx, model):
"""
Cleanup function required to ensure we delete the pipe module at the end of the the test on all workers
"""
del model


@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
def test_rpc_sequential_plugin_manual(tmpdir, args=None):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
@@ -54,12 +47,12 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):

trainer.fit(model)

if torch_distrib.get_rank() == 0:
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0

if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()


@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@@ -68,7 +61,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
def test_rpc_sequential_plugin_manual_amp(tmpdir, args=None):
model = SequentialModelRPCManual()
trainer = Trainer(
max_epochs=2,
@@ -81,22 +74,19 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
distributed_backend="ddp",
plugins=[RPCSequentialPlugin(balance=[2, 1])],
)
try:
with pytest.raises(
MisconfigurationException, match='RPCSequentialPlugin is currently not supported in Automatic Mixed Precision'
):
trainer.fit(model)

assert len(trainer.dev_debugger.pbar_added_metrics) > 0

except MisconfigurationException as e:
assert str(e) == 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision'


@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
def test_rpc_sequential_plugin_automatic(tmpdir, args=None):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
@@ -110,13 +100,12 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):

trainer.fit(model)

if torch_distrib.get_rank() == 0:
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0

if trainer.accelerator_backend.rpc_enabled:

# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()


@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@@ -125,7 +114,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None):
model = SequentialModelRPCAutomatic()
trainer = Trainer(
max_epochs=2,
@@ -137,15 +126,14 @@ def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
plugins=[RPCSequentialPlugin(balance=[2, 2])],
)

try:
with pytest.raises(
MisconfigurationException, match="The provided balance sum: 4 does not match your Sequential length: 3"
):
trainer.fit(model)

except MisconfigurationException as e:
assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3'

if trainer.accelerator_backend.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()


class SequentialModelRPCManual(LightningModule):
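The test hunks replace the manual `try/except MisconfigurationException` plus `str(e)` comparison with `pytest.raises(..., match=...)`. The standalone example below shows that assertion style in isolation; the exception class, helper function, and message are stand-ins for the Lightning ones, not the project's actual code.

```python
import re

import pytest


class MisconfigurationError(Exception):
    pass


def configure(balance_sum: int, sequential_length: int) -> None:
    if balance_sum != sequential_length:
        raise MisconfigurationError(
            f"The provided balance sum: {balance_sum} does not match "
            f"your Sequential length: {sequential_length}"
        )


def test_wrong_balance_raises():
    # `match` is searched against str(exc) as a regex, so literal text containing
    # regex metacharacters should be wrapped in re.escape.
    with pytest.raises(
        MisconfigurationError,
        match=re.escape("The provided balance sum: 4 does not match your Sequential length: 3"),
    ):
        configure(balance_sum=4, sequential_length=3)
```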