Commit 8be8909

base cluster env
1 parent c912c4b commit 8be8909

File tree

16 files changed (+121, -38 lines)


azure-pipelines.yml

Lines changed: 6 additions & 1 deletion
@@ -72,7 +72,12 @@ jobs:
     displayName: 'Get legacy checkpoints'

   - script: |
-      python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
+      # python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
+      python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests --ignore tests/plugins/test_sharded_plugin.py --ignore tests/trainer/test_dataloaders.py --ignore tests/metrics -v --durations=50
+      # TODO: Find out why these tests fail when run in the main pytest run.
+      python -m coverage run -a --source pytorch_lightning -m pytest tests/metrics -v --durations=50
+      python -m coverage run -a --source pytorch_lightning -m pytest tests/plugins/test_sharded_plugin.py tests/trainer/test_dataloaders.py -v --durations=50
+
     displayName: 'Testing: standard'

   - script: |

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum

+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.trainer import Trainer
+

 class Accelerator(object):
     """

pytorch_lightning/accelerators/accelerator_connector.py

File mode changed from 100644 to 100755
Lines changed: 2 additions & 6 deletions
@@ -428,7 +428,7 @@ def select_accelerator(self) -> Accelerator:
             training_type_plugin=self.training_type_plugin,
         )

-    def select_cluster_environment(self) -> ClusterEnvironment:
+    def select_cluster_environment(self):
         if self._cluster_environment is not None:
             return self._cluster_environment
         if self.is_slurm_managing_tasks:
@@ -438,12 +438,8 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         elif self.is_using_torchelastic:
             env = TorchElasticEnvironment()
-            # TODO: decouple DDP from TE
-            # refactor and let generic cluster env hold the information about who spawns the processes
-            os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         else:
-            # TODO: maybe introduce a DefaultEnvironment?
-            env = TorchElasticEnvironment()
+            env = DefaultEnvironment()
        return env

    def set_distributed_mode(self, distributed_backend: Optional[str] = None):

pytorch_lightning/plugins/environments/cluster_environment.py

Lines changed: 22 additions & 12 deletions
@@ -11,26 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Optional

-from pytorch_lightning.plugins.legacy.plugin import LightningPlugin

+class ClusterEnvironment(ABC):
+    """ Specification of a cluster environment. """

-class ClusterEnvironment(LightningPlugin):
+    @abstractmethod
+    def spawns_children(self) -> bool:
+        """ Whether the environment spawns the subprocesses or not. """

-    def __init__(self):
-        self._world_size = None
+    @abstractmethod
+    def master_address(self) -> str:
+        """ The master address through which all processes connect and communicate. """

-    def master_address(self):
-        pass
-
-    def master_port(self):
-        pass
+    @abstractmethod
+    def master_port(self) -> int:
+        """ An open and configured port in the master node through which all processes communicate. """

-    def world_size(self) -> int:
-        return self._world_size
+    @abstractmethod
+    def world_size(self) -> Optional[int]:
+        """ The number of processes across all devices and nodes. """

+    @abstractmethod
     def local_rank(self) -> int:
-        pass
+        """ The rank (index) of the currently running process inside of the current node. """
+
+    @abstractmethod
+    def node_rank(self) -> int:
+        """ The rank (index) of the node on which the current process runs. """

     def node_rank(self) -> int:
         pass
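
With this change the base class is a plain specification (ABC) rather than a LightningPlugin. As a rough illustration of the contract it defines, here is a minimal sketch of a custom environment for a hypothetical scheduler that exports its own rank variables; the class name and the MY_* environment variables are invented for this example, and only the method names and signatures come from the diff above:

import os
from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    """Hypothetical environment for a scheduler that sets its own rank variables."""

    def spawns_children(self) -> bool:
        # The scheduler already launches one process per device,
        # so Lightning should not call the children scripts itself.
        return True

    def master_address(self) -> str:
        return os.environ["MY_MASTER_ADDR"]

    def master_port(self) -> int:
        return int(os.environ["MY_MASTER_PORT"])

    def world_size(self) -> Optional[int]:
        return int(os.environ["MY_WORLD_SIZE"])

    def local_rank(self) -> int:
        return int(os.environ["MY_LOCAL_RANK"])

    def node_rank(self) -> int:
        return int(os.environ["MY_NODE_RANK"])
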
Lines changed: 55 additions & 0 deletions (new file)
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.utilities.distributed import find_free_network_port
+
+
+class DefaultEnvironment(ClusterEnvironment):
+    """
+    A default environment for a single node or free cluster (not managed).
+
+    The master process must be launched by the user and Lightning will spawn new
+    worker processes for distributed training, either in a single node or across multiple nodes.
+
+    If the master address and port are not provided, the default environment will choose them
+    automatically. It is recommended to use this default environment for single-node distributed
+    training as it provides the most convenient way to launch the training script.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._world_size = None
+
+    def spawns_children(self) -> bool:
+        return False
+
+    def master_address(self) -> str:
+        return os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    def master_port(self) -> int:
+        return int(os.environ.get("MASTER_PORT", find_free_network_port()))
+
+    def world_size(self) -> Optional[int]:
+        return None
+
+    def local_rank(self) -> int:
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    def node_rank(self) -> int:
+        group_rank = os.environ.get("GROUP_RANK", 0)
+        return int(os.environ.get("NODE_RANK", group_rank))
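
A short sketch of the fallback behaviour this class encodes, for a manual launch outside any managed cluster. The concrete values are illustrative only, and the module path of the new file is not shown in this view:

import os

# DefaultEnvironment reads these variables if present; otherwise it falls back
# to localhost, a free port from find_free_network_port(), and rank 0 everywhere.
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")  # fallback would be 127.0.0.1
os.environ.setdefault("MASTER_PORT", "29500")     # fallback would be a random free port
os.environ.setdefault("NODE_RANK", "1")           # fallback is GROUP_RANK, then 0
os.environ.setdefault("LOCAL_RANK", "0")          # fallback is 0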

pytorch_lightning/plugins/environments/slurm_environment.py

Lines changed: 4 additions & 1 deletion
@@ -24,6 +24,9 @@ class SLURMEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()

+    def spawns_children(self) -> bool:
+        return True
+
     def master_address(self):
         # figure out the root node addr
         slurm_nodelist = os.environ.get("SLURM_NODELIST")
@@ -65,7 +68,7 @@ def master_port(self):
         return default_port

     def world_size(self):
-        return self._world_size
+        return None

     def local_rank(self):
         return int(os.environ['SLURM_LOCALID'])

pytorch_lightning/plugins/environments/torchelastic_environment.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ class TorchElasticEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()

+    def spawns_children(self) -> bool:
+        return True
+
     def master_address(self):
         if "MASTER_ADDR" not in os.environ:
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 9 additions & 0 deletions
@@ -92,6 +92,15 @@ def backward(
         closure_loss = closure_loss.detach()
         return closure_loss

+    def pre_optimizer_step(
+        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
+    ) -> bool:
+        """Hook to do something before each optimizer step."""
+        # Apex: Amp does not support closure use with optimizers
+        closure()
+        optimizer.step()
+        return False
+
     def configure_apex(
         self,
         amp: object,

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 6 additions & 12 deletions
@@ -27,10 +27,9 @@
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _PYTORCH_GREATER_EQUAL_1_7_0, rank_zero_warn
 from pytorch_lightning.utilities.distributed import (
     find_free_network_port,
     rank_zero_only,
@@ -88,8 +87,7 @@ def setup(self, model):
         self._model = model

         # start the other scripts
-        # TODO: refactor and let generic cluster env hold the information about who spawns the processes
-        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
+        if not self.cluster_environment.spawns_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()

         # set the task idx
@@ -103,15 +101,12 @@ def _call_children_scripts(self):
         self._has_spawned_children = True

         # DDP Environment variables
-        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

         # allow the user to pass the node rank
-        node_rank = "0"
-        node_rank = os.environ.get("NODE_RANK", node_rank)
-        node_rank = os.environ.get("GROUP_RANK", node_rank)
-        os.environ["NODE_RANK"] = node_rank
-        os.environ["LOCAL_RANK"] = "0"
+        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
+        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

         # when user is using hydra find the absolute path
         path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path
@@ -205,7 +200,6 @@ def determine_ddp_device_ids(self):
         return [self.root_device.index]

     def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        # TODO: From where to get cluster environment?
         os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def distributed_sampler_kwargs(self):
     def setup(self, model):
         self._model = model

-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

         # pass in a state q
         smp = mp.get_context("spawn")
