Merged · Changes from all commits (25 commits)
1 change: 1 addition & 0 deletions benchmarks/test_sharded_parity.py
@@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25  # Allow a larger speed diff since only 2 GPUs are sharding 2 optimizers
)


17 changes: 13 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
@@ -21,10 +20,8 @@
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

if torch.distributed.is_available():
@@ -222,6 +219,18 @@ def __setstate__(self, d):
def on_save(self, checkpoint):
return checkpoint

@property
def rpc_enabled(self):
return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin)

@property
def distributed_sampler_kwargs(self):
raise NotImplementedError
Member

Do we want a default implementation for this, i.e. returning an empty dict here by default? I don't think having a property raise NotImplementedError is good practice.


@property
def require_distributed_sampler(self):
raise NotImplementedError
Member

Same as above. Can we make it return False by default?

Contributor

I think it is fine for both. An Accelerator shouldn't be used directly; it is a metaclass.
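
For reference, a minimal sketch of the defaults discussed above, assuming they would sit on the base `Accelerator` class; the property names match this PR, but the bodies are only the reviewers' suggestion, not what the diff implements:

```python
class Accelerator:
    # Sketch of the suggested defaults; the real class has many more members.

    @property
    def distributed_sampler_kwargs(self):
        # Suggested default: an empty dict, so single-device accelerators
        # would not need an override; DDP-style accelerators still override it.
        return {}

    @property
    def require_distributed_sampler(self):
        # Suggested default: False; the DDP/DDP2/HPC/CPU-spawn accelerators
        # in this diff override it to return True.
        return False
```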


@contextmanager
def block_ddp_plugin_sync_behaviour(self):
"""
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -90,3 +90,7 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return tensor

@property
def require_distributed_sampler(self):
return False
37 changes: 34 additions & 3 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -23,6 +23,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx):
def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -133,6 +136,9 @@ def ddp_train(self, process_idx, mp_queue, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -143,6 +149,15 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -158,12 +173,14 @@ def ddp_train(self, process_idx, mp_queue, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -189,7 +206,7 @@ def ddp_train(self, process_idx, mp_queue, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -219,3 +236,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
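
As an aside, here is a hedged sketch of how a trainer-side consumer might combine the two new properties; it is illustrative only and not part of this diff, and `dataset`, `accelerator`, and `build_train_dataloader` are placeholder names:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_train_dataloader(dataset, accelerator, batch_size=32):
    # When the accelerator asks for a distributed sampler, build one from the
    # kwargs it exposes (num_replicas/rank, possibly rewritten by the ddp_plugin,
    # e.g. an RPC/pipe plugin that shards over fewer data-parallel replicas).
    sampler = None
    if accelerator.require_distributed_sampler:
        sampler = DistributedSampler(dataset, **accelerator.distributed_sampler_kwargs)
    # Shuffle only when no sampler is attached; DistributedSampler shuffles internally.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=sampler is None)
```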
46 changes: 39 additions & 7 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -27,6 +27,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -162,8 +163,11 @@ def _step(self, args):
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
if self.rpc_enabled:
# Allow RPC to handle barrier on main RPC processes
self.ddp_plugin.barrier()
elif torch_distrib.is_initialized():
torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group)

def _check_can_spawn_children(self):
if self._has_spawned_children:
@@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank]
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -192,12 +198,12 @@ def on_train_end(self):
def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
self.barrier('early_stopping')
should_stop = stop == self.trainer.world_size
return should_stop

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)
return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group)

def ddp_train(self, process_idx, model):
"""
@@ -226,6 +232,9 @@ def ddp_train(self, process_idx, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -236,6 +245,15 @@ def ddp_train(self, process_idx, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -251,7 +269,7 @@ def ddp_train(self, process_idx, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -284,7 +302,7 @@ def ddp_train(self, process_idx, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -317,3 +335,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
28 changes: 27 additions & 1 deletion pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -24,6 +24,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
@@ -107,6 +108,15 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -128,6 +138,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -221,7 +233,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
mp_queue.put(results)

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -251,3 +263,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
34 changes: 31 additions & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -23,6 +23,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -62,9 +63,11 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -136,6 +139,15 @@ def ddp_train(self, process_idx, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -151,12 +163,14 @@ def ddp_train(self, process_idx, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -183,7 +197,7 @@ def ddp_train(self, process_idx, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -213,3 +227,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True