Commit 03bcc33

Merge b1136ac into 248a8e8
2 parents 248a8e8 + b1136ac commit 03bcc33

19 files changed: +287 additions, -91 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


pytorch_lightning/plugins/environments/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment  # noqa: F401
+from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401

pytorch_lightning/plugins/environments/cluster_environment.py

Lines changed: 20 additions & 11 deletions

@@ -11,24 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Optional


-class ClusterEnvironment:
+class ClusterEnvironment(ABC):
+    """ Specification of a cluster environment. """

-    def __init__(self):
-        self._world_size = None
+    @abstractmethod
+    def creates_children(self) -> bool:
+        """ Whether the environment creates the subprocesses or not. """

-    def master_address(self):
-        pass
+    @abstractmethod
+    def master_address(self) -> str:
+        """ The master address through which all processes connect and communicate. """

-    def master_port(self):
-        pass
+    @abstractmethod
+    def master_port(self) -> int:
+        """ An open and configured port in the master node through which all processes communicate. """

-    def world_size(self) -> int:
-        return self._world_size
+    @abstractmethod
+    def world_size(self) -> Optional[int]:
+        """ The number of processes across all devices and nodes. """

+    @abstractmethod
     def local_rank(self) -> int:
-        pass
+        """ The rank (index) of the currently running process inside of the current node. """

+    @abstractmethod
     def node_rank(self) -> int:
-        pass
+        """ The rank (index) of the node on which the current process runs. """
pytorch_lightning/plugins/environments/lightning_environment.py

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import socket
+from typing import Optional
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+
+
+class LightningEnvironment(ClusterEnvironment):
+    """
+    The default environment used by Lightning for a single node or free cluster (not managed).
+
+    The master process must be launched by the user and Lightning will spawn new
+    worker processes for distributed training, either in a single node or across multiple nodes.
+
+    If the master address and port are not provided, the default environment will choose them
+    automatically. It is recommended to use this default environment for single-node distributed
+    training as it provides the most convenient way to launch the training script.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._master_port = None
+
+    def creates_children(self) -> bool:
+        return False
+
+    def master_address(self) -> str:
+        return os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    def master_port(self) -> int:
+        if self._master_port is None:
+            self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
+        return int(self._master_port)
+
+    def world_size(self) -> Optional[int]:
+        return None
+
+    def local_rank(self) -> int:
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    def node_rank(self) -> int:
+        group_rank = os.environ.get("GROUP_RANK", 0)
+        return int(os.environ.get("NODE_RANK", group_rank))
+
+
+def find_free_network_port() -> int:
+    """
+    Finds a free port on localhost.
+    It is useful in single-node training when we don't want to connect to a real master node but
+    have to set the `MASTER_PORT` environment variable.
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    s.listen(1)
+    port = s.getsockname()[1]
+    s.close()
+    return port

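A quick sketch of how the new defaults resolve when nothing is pre-configured; the environment variables are cleared here purely for illustration, and the exact port value will differ per run:

import os

from pytorch_lightning.plugins.environments import LightningEnvironment

# start from a clean slate so the fallbacks are visible
for var in ("MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK", "NODE_RANK", "GROUP_RANK"):
    os.environ.pop(var, None)

env = LightningEnvironment()
print(env.creates_children())      # False: Lightning itself launches the worker scripts
print(env.master_address())        # "127.0.0.1" since MASTER_ADDR is unset
port = env.master_port()           # a free port picked by find_free_network_port()
print(port == env.master_port())   # True: the chosen port is cached and reused
print(env.local_rank(), env.node_rank())  # 0 0
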
pytorch_lightning/plugins/environments/slurm_environment.py

Lines changed: 10 additions & 7 deletions

@@ -26,7 +26,10 @@ class SLURMEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()

-    def master_address(self):
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
         # figure out the root node addr
         slurm_nodelist = os.environ.get("SLURM_NODELIST")
         if slurm_nodelist:
@@ -39,7 +42,7 @@ def master_address(self):
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
         return root_node

-    def master_port(self):
+    def master_port(self) -> int:
         # -----------------------
         # SLURM JOB = PORT number
         # -----------------------
@@ -64,18 +67,18 @@ def master_port(self):

         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        return default_port
+        return int(default_port)

     def world_size(self):
-        return self._world_size
+        return None

-    def local_rank(self):
+    def local_rank(self) -> int:
         return int(os.environ['SLURM_LOCALID'])

-    def node_rank(self):
+    def node_rank(self) -> int:
         return int(os.environ['SLURM_NODEID'])

-    def resolve_root_node_address(self, root_node):
+    def resolve_root_node_address(self, root_node: str) -> str:
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
             number = numbers.split(',', maxsplit=1)[0]

pytorch_lightning/plugins/environments/torchelastic_environment.py

Lines changed: 11 additions & 6 deletions

@@ -14,6 +14,7 @@

 import logging
 import os
+from typing import Optional

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_warn
@@ -26,27 +27,31 @@ class TorchElasticEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()

-    def master_address(self):
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
         if "MASTER_ADDR" not in os.environ:
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ["MASTER_ADDR"] = "127.0.0.1"
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
         master_address = os.environ.get('MASTER_ADDR')
         return master_address

-    def master_port(self):
+    def master_port(self) -> int:
         if "MASTER_PORT" not in os.environ:
             rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        port = os.environ.get('MASTER_PORT')
+        port = int(os.environ.get('MASTER_PORT'))
         return port

-    def world_size(self):
-        return os.environ.get('WORLD_SIZE')
+    def world_size(self) -> Optional[int]:
+        world_size = os.environ.get('WORLD_SIZE')
+        return int(world_size) if world_size is not None else world_size

-    def local_rank(self):
+    def local_rank(self) -> int:
         return int(os.environ['LOCAL_RANK'])

     def node_rank(self) -> int:

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 6 additions & 17 deletions

@@ -30,20 +30,14 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
-from pytorch_lightning.utilities.distributed import (
-    find_free_network_port,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything

 if _HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path

-
 log = logging.getLogger(__name__)


@@ -90,8 +84,7 @@ def setup(self, model):
         self._model = model

         # start the other scripts
-        # TODO: refactor and let generic cluster env hold the information about who spawns the processes
-        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
+        if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()

         # set the task idx
@@ -105,15 +98,12 @@ def _call_children_scripts(self):
         self._has_spawned_children = True

         # DDP Environment variables
-        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

         # allow the user to pass the node rank
-        node_rank = "0"
-        node_rank = os.environ.get("NODE_RANK", node_rank)
-        node_rank = os.environ.get("GROUP_RANK", node_rank)
-        os.environ["NODE_RANK"] = node_rank
-        os.environ["LOCAL_RANK"] = "0"
+        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
+        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

         # when user is using hydra find the absolute path
         path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path
@@ -209,7 +199,6 @@ def determine_ddp_device_ids(self):
         return [self.root_device.index]

     def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        # TODO: From where to get cluster environment?
         os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

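The condition added to `DDPPlugin.setup()` above replaces the old hard-coded `PL_IN_DDP_SUBPROCESS` handling: environments that launch their own workers (SLURM, TorchElastic) report `creates_children() == True` and skip `_call_children_scripts()`, while `LightningEnvironment` reports `False` and lets the master process spawn the children. A simplified restatement of that gate (illustrative helper, not code from the commit):

import os

def should_call_children_scripts(cluster_environment) -> bool:
    # spawn child scripts only if the environment does not create workers itself
    # and this process is not already one of the spawned children
    in_subprocess = os.environ.get("PL_IN_DDP_SUBPROCESS", "0") == "1"
    return not cluster_environment.creates_children() and not in_subprocess
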
pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 9 deletions

@@ -30,13 +30,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    find_free_network_port,
-    rank_zero_only,
-    rank_zero_warn,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.seed import seed_everything

 log = logging.getLogger(__name__)
@@ -84,7 +78,7 @@ def distributed_sampler_kwargs(self):
     def setup(self, model):
         self._model = model

-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

         # pass in a state q
         smp = mp.get_context("spawn")
@@ -93,7 +87,7 @@ def setup(self, model):
     def set_world_ranks(self, process_idx):
         self.local_rank = process_idx
         self.node_rank = self.cluster_environment.node_rank()
-        self.task_idx = self.cluster_local_rank
+        self.task_idx = self.cluster_environment.local_rank()
         self.global_rank = self.node_rank * self.num_processes + self.local_rank
         self.world_size = self.num_nodes * self.num_processes

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 0 additions & 7 deletions

@@ -40,13 +40,6 @@ def __init__(
         self.local_rank = 0
         self.cluster_environment = cluster_environment

-    @property
-    def cluster_local_rank(self):
-        try:
-            return self.cluster_environment.local_rank()
-        except KeyError:
-            return 0
-
     @property
     @abstractmethod
     def root_device(self):

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 7 additions & 9 deletions

@@ -42,7 +42,12 @@
     TPUSpawnPlugin,
     TrainingTypePlugin,
 )
-from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import (
+    ClusterEnvironment,
+    LightningEnvironment,
+    SLURMEnvironment,
+    TorchElasticEnvironment,
+)
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
@@ -451,17 +456,10 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             return self._cluster_environment
         if self.is_slurm_managing_tasks:
             env = SLURMEnvironment()
-            # TODO: decouple DDP from SLURM
-            # refactor and let generic cluster env hold the information about who spawns the processes
-            os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         elif self.is_using_torchelastic:
             env = TorchElasticEnvironment()
-            # TODO: decouple DDP from TE
-            # refactor and let generic cluster env hold the information about who spawns the processes
-            os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         else:
-            # TODO: maybe introduce a DefaultEnvironment?
-            env = TorchElasticEnvironment()
+            env = LightningEnvironment()
         return env

     def set_distributed_mode(self, distributed_backend: Optional[str] = None):

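With this change, `select_cluster_environment()` falls back to the new `LightningEnvironment` whenever the job is not managed by SLURM or TorchElastic, and an environment supplied by the user still wins via the `self._cluster_environment` early return. A usage sketch, assuming (as in Lightning 1.2) that a `ClusterEnvironment` instance can be passed to the `Trainer` through the `plugins` argument:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment

# Without SLURM or TorchElastic variables set, the connector would pick
# LightningEnvironment on its own; passing it explicitly overrides the selection.
trainer = Trainer(accelerator="ddp", gpus=2, plugins=[LightningEnvironment()])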