
Commit ac71300

Clean up environment access in plugins (#6941)
1 parent 4c2005d commit ac71300

30 files changed: +715, -210 lines

CHANGELOG.md

Lines changed: 3 additions & 18 deletions
@@ -33,24 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
-- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
-
-
-- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
-
-
-- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
-
-
-- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
-
-
-- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))
-
-
-- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))
-
-
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
@@ -240,6 +222,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
 
 
+- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed

pytorch_lightning/plugins/environments/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
+from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401

pytorch_lightning/plugins/environments/cluster_environment.py

Lines changed: 30 additions & 10 deletions
@@ -11,24 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
 
 
-class ClusterEnvironment:
+class ClusterEnvironment(ABC):
+    """ Specification of a cluster environment. """
 
-    def __init__(self):
-        self._world_size = None
+    @abstractmethod
+    def creates_children(self) -> bool:
+        """ Whether the environment creates the subprocesses or not. """
 
-    def master_address(self):
-        pass
+    @abstractmethod
+    def master_address(self) -> str:
+        """ The master address through which all processes connect and communicate. """
 
-    def master_port(self):
-        pass
+    @abstractmethod
+    def master_port(self) -> int:
+        """ An open and configured port in the master node through which all processes communicate. """
 
+    @abstractmethod
     def world_size(self) -> int:
-        return self._world_size
+        """ The number of processes across all devices and nodes. """
 
-    def local_rank(self) -> int:
+    @abstractmethod
+    def set_world_size(self, size: int) -> None:
         pass
 
-    def node_rank(self) -> int:
+    @abstractmethod
+    def global_rank(self) -> int:
+        """ The rank (index) of the currently running process across all nodes and devices. """
+
+    @abstractmethod
+    def set_global_rank(self, rank: int) -> None:
         pass
+
+    @abstractmethod
+    def local_rank(self) -> int:
+        """ The rank (index) of the currently running process inside of the current node. """
+
+    @abstractmethod
+    def node_rank(self) -> int:
+        """ The rank (index) of the node on which the current process runs. """
pytorch_lightning/plugins/environments/lightning_environment.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import socket
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class LightningEnvironment(ClusterEnvironment):
+    """
+    The default environment used by Lightning for a single node or free cluster (not managed).
+
+    The master process must be launched by the user and Lightning will spawn new
+    worker processes for distributed training, either in a single node or across multiple nodes.
+
+    If the master address and port are not provided, the default environment will choose them
+    automatically. It is recommended to use this default environment for single-node distributed
+    training as it provides the most convenient way to launch the training script.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._master_port = None
+        self._global_rank: int = 0
+        self._world_size: int = 1
+
+    def creates_children(self) -> bool:
+        return False
+
+    def master_address(self) -> str:
+        return os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    def master_port(self) -> int:
+        if self._master_port is None:
+            self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
+        return int(self._master_port)
+
+    def world_size(self) -> int:
+        return self._world_size
+
+    def set_world_size(self, size: int) -> None:
+        self._world_size = size
+
+    def global_rank(self) -> int:
+        return self._global_rank
+
+    def set_global_rank(self, rank: int) -> None:
+        self._global_rank = rank
+        rank_zero_only.rank = rank
+
+    def local_rank(self) -> int:
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    def node_rank(self) -> int:
+        group_rank = os.environ.get("GROUP_RANK", 0)
+        return int(os.environ.get("NODE_RANK", group_rank))
+
+
+def find_free_network_port() -> int:
+    """
+    Finds a free port on localhost.
+    It is useful in single-node training when we don't want to connect to a real master node but
+    have to set the `MASTER_PORT` environment variable.
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    s.listen(1)
+    port = s.getsockname()[1]
+    s.close()
+    return port
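
A quick way to see the behaviour of the new default environment, assuming a fresh process where neither `MASTER_ADDR` nor `MASTER_PORT` is exported:

import os

from pytorch_lightning.plugins.environments import LightningEnvironment

env = LightningEnvironment()

# With no MASTER_ADDR/MASTER_PORT exported, the environment falls back to
# localhost and picks a free port once, then reuses the cached value.
assert env.master_address() == os.environ.get("MASTER_ADDR", "127.0.0.1")
port = env.master_port()
assert env.master_port() == port

# Rank and world size default to single-process values until the setters
# added in this commit are called (in practice by Lightning's distributed plugins).
assert env.creates_children() is False
assert env.global_rank() == 0 and env.world_size() == 1
env.set_world_size(4)
env.set_global_rank(2)
assert env.world_size() == 4 and env.global_rank() == 2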

pytorch_lightning/plugins/environments/slurm_environment.py

Lines changed: 19 additions & 10 deletions
@@ -23,10 +23,10 @@
 
 class SLURMEnvironment(ClusterEnvironment):
 
-    def __init__(self):
-        super().__init__()
+    def creates_children(self) -> bool:
+        return True
 
-    def master_address(self):
+    def master_address(self) -> str:
         # figure out the root node addr
         slurm_nodelist = os.environ.get("SLURM_NODELIST")
         if slurm_nodelist:
@@ -39,7 +39,7 @@ def master_address(self):
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
         return root_node
 
-    def master_port(self):
+    def master_port(self) -> int:
         # -----------------------
         # SLURM JOB = PORT number
         # -----------------------
@@ -64,18 +64,27 @@ def master_port(self):
 
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        return default_port
+        return int(default_port)
 
-    def world_size(self):
-        return self._world_size
+    def world_size(self) -> int:
+        return int(os.environ["SLURM_NTASKS"])
 
-    def local_rank(self):
+    def set_world_size(self, size: int) -> None:
+        log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["SLURM_PROCID"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+
+    def local_rank(self) -> int:
         return int(os.environ['SLURM_LOCALID'])
 
-    def node_rank(self):
+    def node_rank(self) -> int:
         return int(os.environ['SLURM_NODEID'])
 
-    def resolve_root_node_address(self, root_node):
+    def resolve_root_node_address(self, root_node: str) -> str:
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
             number = numbers.split(',', maxsplit=1)[0]
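
After this change, `SLURMEnvironment` reads the world size and ranks straight from SLURM's own variables and the new setters are deliberate no-ops. A small sketch follows; the `SLURM_*` variables are set by hand purely for illustration, since SLURM itself exports them in a real job:

import os

from pytorch_lightning.plugins.environments import SLURMEnvironment

# In a real job these are exported by SLURM; they are faked here for the example.
os.environ.update({
    "SLURM_NTASKS": "8",
    "SLURM_PROCID": "3",
    "SLURM_LOCALID": "1",
    "SLURM_NODEID": "0",
})

env = SLURMEnvironment()
assert env.creates_children() is True   # SLURM launches the processes itself
assert env.world_size() == 8            # from SLURM_NTASKS
assert env.global_rank() == 3           # from SLURM_PROCID
assert env.local_rank() == 1 and env.node_rank() == 0

# The setters only emit a debug message; the SLURM-provided values always win.
env.set_world_size(1)
env.set_global_rank(0)
assert env.world_size() == 8 and env.global_rank() == 3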

pytorch_lightning/plugins/environments/torchelastic_environment.py

Lines changed: 27 additions & 8 deletions
@@ -14,6 +14,7 @@
 
 import logging
 import os
+from typing import Optional
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_warn
@@ -23,30 +24,48 @@
 
 class TorchElasticEnvironment(ClusterEnvironment):
 
-    def __init__(self):
-        super().__init__()
+    @staticmethod
+    def is_using_torchelastic() -> bool:
+        """ Returns ``True`` if the current process was launched using the torchelastic command. """
+        required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
+        return all(v in os.environ for v in required_env_vars)
 
-    def master_address(self):
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
         if "MASTER_ADDR" not in os.environ:
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ["MASTER_ADDR"] = "127.0.0.1"
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
         master_address = os.environ.get('MASTER_ADDR')
         return master_address
 
-    def master_port(self):
+    def master_port(self) -> int:
         if "MASTER_PORT" not in os.environ:
             rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        port = os.environ.get('MASTER_PORT')
+        port = int(os.environ.get('MASTER_PORT'))
         return port
 
-    def world_size(self):
-        return os.environ.get('WORLD_SIZE')
+    def world_size(self) -> Optional[int]:
+        world_size = os.environ.get('WORLD_SIZE')
+        return int(world_size) if world_size is not None else world_size
+
+    def set_world_size(self, size: int) -> None:
+        log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["RANK"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug(
+            "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
+        )
 
-    def local_rank(self):
+    def local_rank(self) -> int:
         return int(os.environ['LOCAL_RANK'])
 
     def node_rank(self) -> int:
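
The new static `is_using_torchelastic()` helper makes the detection logic reusable outside the class. How the `Trainer` internals pick an environment is not part of this excerpt; the sketch below only illustrates the helper itself, falling back to the default `LightningEnvironment` when the torchelastic variables are absent:

from pytorch_lightning.plugins.environments import LightningEnvironment, TorchElasticEnvironment

# is_using_torchelastic() checks that RANK, GROUP_RANK, LOCAL_RANK and
# LOCAL_WORLD_SIZE are all present, i.e. the process was started by the
# torchelastic launcher.
if TorchElasticEnvironment.is_using_torchelastic():
    env = TorchElasticEnvironment()
else:
    env = LightningEnvironment()

print(type(env).__name__, "creates_children:", env.creates_children())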
