Commit 987530c

shuyingsunshine21, tchaton, kaushikb11, carmocca and awaelchli authored
Set num_nodes and sync_batchnorm From Trainer for Manually Passed Training Type Plugin (#7026)
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 710b144 commit 987530c
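
In practice, this change lets a manually constructed training type plugin pick up `num_nodes` and `sync_batchnorm` from the `Trainer` arguments instead of requiring them on the plugin itself. A minimal usage sketch under that assumption (the device and node counts below are placeholder values, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# The plugin no longer needs num_nodes/sync_batchnorm; the Trainer values are
# propagated to it by resolve_training_type_plugin and override any plugin value.
trainer = Trainer(
    gpus=2,                 # placeholder device count
    num_nodes=2,            # propagated to the manually passed plugin
    sync_batchnorm=True,    # likewise propagated
    plugins=[DDPPlugin(find_unused_parameters=False)],
)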

File tree

8 files changed: +139 −44 lines


CHANGELOG.md

Lines changed: 9 additions & 3 deletions
@@ -22,12 +22,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
 
 
+- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
+
+
+- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
 ### Deprecated
 
 
 - Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
 
 
+- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
 ### Removed
 
 
@@ -144,9 +153,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure accelerator is valid if running interactively ([#5970](https://github.com/PyTorchLightning/pytorch-lightning/pull/5970))
 - Disabled batch transfer in DP mode ([#6098](https://github.com/PyTorchLightning/pytorch-lightning/pull/6098))
 
-- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
-
-
 ### Deprecated
 
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 39 additions & 9 deletions
@@ -33,6 +33,7 @@
     _HYDRA_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
+    rank_zero_deprecation,
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
@@ -62,23 +63,33 @@ class DDPPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: bool = False,
+        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.interactive_ddp_procs = []
-        self.num_nodes = num_nodes
-        self.sync_batchnorm = sync_batchnorm
+        if num_nodes is not None:
+            rank_zero_deprecation(
+                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
+                " Notice that it will be overriden by the trainer setting."
+            )
+        self._num_nodes = num_nodes or 1
+        if sync_batchnorm is not None:
+            rank_zero_deprecation(
+                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
+                " Notice that it will be overriden by the trainer setting."
+            )
+        self._sync_batchnorm = sync_batchnorm or False
         self.dist = LightningDistributed()
+        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._has_spawned_children = False
         self.task_idx = None
-        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -88,6 +99,24 @@ def __init__(
     def root_device(self):
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_nodes(self) -> int:
+        return self._num_nodes
+
+    @num_nodes.setter
+    def num_nodes(self, num_nodes: int) -> None:
+        # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
+        self._num_nodes = num_nodes
+        self.set_world_ranks()
+
+    @property
+    def sync_batchnorm(self) -> bool:
+        return self._sync_batchnorm
+
+    @sync_batchnorm.setter
+    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
+        self._sync_batchnorm = sync_batchnorm
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
@@ -212,10 +241,11 @@ def _check_can_spawn_children(self):
             )
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is not None:
-            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-            rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
 
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
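
The constructor and property changes above follow one pattern: the deprecated constructor argument is stored in a private field, and the public property's setter re-derives the dependent state (here, the world ranks). A self-contained illustrative sketch of that pattern, not Lightning source:

import warnings


class PluginSketch:
    """Illustrative only: mirrors the deprecation + property pattern used in DDPPlugin."""

    def __init__(self, num_nodes=None):
        if num_nodes is not None:
            warnings.warn(
                "`num_nodes` is deprecated; the trainer setting will override it.",
                DeprecationWarning,
            )
        self._num_nodes = num_nodes or 1

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # world size depends on num_nodes, so recompute it whenever the value changes
        self._num_nodes = num_nodes
        self.set_world_ranks()

    def set_world_ranks(self) -> None:
        print(f"world ranks recomputed for {self._num_nodes} node(s)")


plugin = PluginSketch()  # no warning: nothing was passed explicitly
plugin.num_nodes = 4     # trainer-style override triggers set_world_ranks()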

pytorch_lightning/plugins/training_type/ddp2.py

Lines changed: 3 additions & 1 deletion
@@ -72,6 +72,8 @@ def distributed_sampler_kwargs(self):
     def _is_single_process_single_device(self) -> bool:
         return False
 
-    def set_world_ranks(self):
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
         self.cluster_environment.set_global_rank(self.node_rank)
         self.cluster_environment.set_world_size(self.num_nodes)

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 44 additions & 9 deletions
@@ -31,7 +31,13 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    rank_zero_deprecation,
+    rank_zero_only,
+    rank_zero_warn,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _TORCH_GREATER_EQUAL_1_8:
@@ -51,17 +57,27 @@ class DDPSpawnPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: bool = False,
+        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
         **kwargs: Any,
     ):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
-        self.num_nodes = num_nodes
-        self.sync_batchnorm = sync_batchnorm
+        if num_nodes is not None:
+            rank_zero_deprecation(
+                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
+                "Notice that it will be overriden by the trainer setting."
+            )
+        self._num_nodes = num_nodes or 1
+        if sync_batchnorm is not None:
+            rank_zero_deprecation(
+                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
+                "Notice that it will be overriden by the trainer setting."
+            )
+        self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
@@ -72,6 +88,24 @@ def __init__(
         self._local_rank = 0
         self.set_world_ranks()
 
+    @property
+    def num_nodes(self) -> int:
+        return self._num_nodes
+
+    @num_nodes.setter
+    def num_nodes(self, num_nodes: int) -> None:
+        # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
+        self._num_nodes = num_nodes
+        self.set_world_ranks()
+
+    @property
+    def sync_batchnorm(self) -> bool:
+        return self._sync_batchnorm
+
+    @sync_batchnorm.setter
+    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
+        self._sync_batchnorm = sync_batchnorm
+
     @property
     def local_rank(self) -> int:
         return self._local_rank
@@ -106,10 +140,11 @@ def setup(self, model):
 
     def set_world_ranks(self, process_idx: int = 0) -> None:
         self._local_rank = process_idx
-        if self.cluster_environment is not None:
-            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-            rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @property
     def mp_spawn_kwargs(self):

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def __init__(
         logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         loss_scale: float = 0,

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 12 additions & 15 deletions
@@ -133,6 +133,7 @@ def __init__(
 
         self.handle_given_plugins()
 
+        self._training_type_plugin_resolved = False
         self.accelerator = self.select_accelerator()
 
         # override dist backend when using tpus
@@ -222,10 +223,13 @@ def precision_plugin(self) -> PrecisionPlugin:
 
     @property
     def training_type_plugin(self) -> TrainingTypePlugin:
+        if self._training_type_plugin_resolved:
+            # avoid calling `resolve_training_type_plugin` multiple times
+            return self._training_type_plugin
         if self._training_type_plugin is None:
             self._training_type_plugin = self.select_training_type_plugin()
-        else:
-            self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+        self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+        self._training_type_plugin_resolved = True
 
         return self._training_type_plugin
 
@@ -320,7 +324,6 @@ def is_using_torchelastic(self) -> bool:
         """
        .. deprecated:: v1.3
            Will be removed in v1.5.0.
-
        Returns:
            ``True`` if the current process was launched using the torchelastic command.
        """
@@ -385,15 +388,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         if self.use_ddp2:
             plugin = DDP2Plugin(
                 parallel_devices=self.parallel_devices,
-                num_nodes=self.num_nodes,
                 cluster_environment=self.cluster_environment,
-                sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_ddp and self.use_deepspeed:
             plugin = DeepSpeedPlugin(
-                num_nodes=self.num_nodes,
-                cluster_environment=self.select_cluster_environment(),
-                parallel_devices=self.parallel_devices
+                cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -426,9 +425,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
 
             plugin = ddp_plugin_cls(
                 parallel_devices=self.parallel_devices,
-                num_nodes=self.num_nodes,
                 cluster_environment=self.cluster_environment,
-                sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_dp:
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
@@ -443,20 +440,20 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
         # necessary for when the user has passed in a plugin
-        if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
+        if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None:
             training_type.parallel_devices = self.parallel_devices
             if hasattr(training_type, 'num_processes'):
                 training_type.num_processes = len(self.parallel_devices)
 
         if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
             training_type.cluster_environment = self.select_cluster_environment()
 
-        if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
+        if hasattr(training_type, 'num_nodes'):
+            # set num_nodes for training_type from trainer setting
            training_type.num_nodes = self.num_nodes
 
-        # Automatically set sync_batchnorm if None.
-        # Useful for custom plugins.
-        if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
+        if hasattr(training_type, 'sync_batchnorm'):
+            # set sync_batchnorm for training_type from trainer setting
            training_type.sync_batchnorm = self.sync_batchnorm
 
        return training_type
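
For a custom plugin, the practical effect of the resolver change is that `num_nodes` and `sync_batchnorm` are now always overwritten from the trainer whenever the attributes exist, not only when they are still `None`. A standalone sketch of that resolution logic (the class names below are illustrative stand-ins, not Lightning source):

class TrainerStub:
    num_nodes = 2
    sync_batchnorm = True


class CustomPlugin:
    num_nodes = 1            # value set manually; will be overwritten
    sync_batchnorm = False   # likewise overwritten


def resolve(trainer, plugin):
    # same hasattr-based overwrite as resolve_training_type_plugin after this commit
    if hasattr(plugin, "num_nodes"):
        plugin.num_nodes = trainer.num_nodes
    if hasattr(plugin, "sync_batchnorm"):
        plugin.sync_batchnorm = trainer.sync_batchnorm
    return plugin


plugin = resolve(TrainerStub(), CustomPlugin())
assert plugin.num_nodes == 2 and plugin.sync_batchnorm is True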

tests/deprecated_api/test_remove_1-6.py

Lines changed: 21 additions & 0 deletions
@@ -16,6 +16,7 @@
 import pytest
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
 from tests.helpers import BoringModel
 
 
@@ -28,3 +29,23 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
 
     with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
         trainer.has_arg("training_step", "batch")
+
+
+def test_v1_6_0_ddp_num_nodes():
+    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
+        DDPPlugin(num_nodes=1)
+
+
+def test_v1_6_0_ddp_sync_batchnorm():
+    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
+        DDPPlugin(sync_batchnorm=False)
+
+
+def test_v1_6_0_ddp_spawn_num_nodes():
+    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
+        DDPSpawnPlugin(num_nodes=1)
+
+
+def test_v1_6_0_ddp_spawn_sync_batchnorm():
+    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
+        DDPSpawnPlugin(sync_batchnorm=False)
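
These tests use `pytest.deprecated_call`, which passes only when the wrapped call emits a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches the given pattern. A minimal, self-contained sketch of the same idiom with a stand-in function instead of the Lightning plugins:

import warnings

import pytest


def legacy_constructor(num_nodes=None):
    # stand-in for a constructor that still accepts a deprecated argument
    if num_nodes is not None:
        warnings.warn("Argument `num_nodes` is deprecated in v1.4", DeprecationWarning)


def test_legacy_constructor_warns():
    with pytest.deprecated_call(match="`num_nodes` is deprecated in v1.4"):
        legacy_constructor(num_nodes=1)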

tests/plugins/test_cluster_integration.py

Lines changed: 10 additions & 6 deletions
@@ -60,13 +60,14 @@ def environment_combinations():
 
 
 @pytest.mark.parametrize(
-    "plugin_cls", [
+    "plugin_cls",
+    [
         DDPPlugin,
         DDPShardedPlugin,
         DDP2Plugin,
         pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
         pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)),
-    ]
+    ],
 )
 def test_ranks_available_manual_plugin_selection(plugin_cls):
     """ Test that the rank information is readily available after Trainer initialization. """
@@ -79,10 +80,12 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
     with mock.patch.dict(os.environ, variables):
         plugin = plugin_cls(
             parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)],
-            num_nodes=num_nodes,
             cluster_environment=cluster,
         )
-        trainer = Trainer(plugins=[plugin])
+        trainer = Trainer(
+            plugins=[plugin],
+            num_nodes=num_nodes,
+        )
         assert rank_zero_only.rank == expected["global_rank"]
         assert trainer.global_rank == expected["global_rank"]
         assert trainer.local_rank == expected["local_rank"]
@@ -91,13 +94,14 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
 
 
 @pytest.mark.parametrize(
-    "trainer_kwargs", [
+    "trainer_kwargs",
+    [
         dict(accelerator="ddp", gpus=[1, 2]),
         dict(accelerator="ddp_sharded", gpus=[1, 2]),
         dict(accelerator="ddp2", gpus=[1, 2]),
         dict(accelerator="ddp_cpu", num_processes=2),
         dict(accelerator="ddp_spawn", gpus=[1, 2]),
-    ]
+    ],
 )
 @mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=4)
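
The rank values this test asserts on come directly from the arithmetic in `set_world_ranks`. A small worked example of that arithmetic (the node and process counts are example values only):

# Example values only; mirrors the formula used in set_world_ranks.
num_nodes = 2      # now taken from Trainer(num_nodes=...)
num_processes = 2  # one process per parallel device on each node


def global_rank(node_rank: int, local_rank: int) -> int:
    return node_rank * num_processes + local_rank


world_size = num_nodes * num_processes

assert global_rank(node_rank=1, local_rank=0) == 2
assert world_size == 4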
