2727from pytorch_lightning .distributed import LightningDistributed
2828from pytorch_lightning .overrides import LightningDistributedModule
2929from pytorch_lightning .overrides .distributed import prepare_for_backward
30- from pytorch_lightning .plugins .environments import SLURMEnvironment , TorchElasticEnvironment
3130from pytorch_lightning .plugins .environments .cluster_environment import ClusterEnvironment
3231from pytorch_lightning .plugins .training_type .parallel import ParallelPlugin
33- from pytorch_lightning .utilities import _HYDRA_AVAILABLE , _TORCH_GREATER_EQUAL_1_7 , rank_zero_warn
32+ from pytorch_lightning .utilities import _HYDRA_AVAILABLE , _PYTORCH_GREATER_EQUAL_1_7_0 , rank_zero_warn
3433from pytorch_lightning .utilities .distributed import (
3534 find_free_network_port ,
3635 rank_zero_only ,
@@ -88,8 +87,7 @@ def setup(self, model):
8887 self ._model = model
8988
9089 # start the other scripts
91- # TODO: refactor and let generic cluster env hold the information about who spawns the processes
92- if os .environ .get ("PL_IN_DDP_SUBPROCESS" , "0" ) != "1" :
90+ if not self .cluster_environment .spawns_children () and os .environ .get ("PL_IN_DDP_SUBPROCESS" , "0" ) != "1" :
9391 self ._call_children_scripts ()
9492
9593 # set the task idx
@@ -103,15 +101,12 @@ def _call_children_scripts(self):
103101 self ._has_spawned_children = True
104102
105103 # DDP Environment variables
106- os .environ ["MASTER_ADDR" ] = os . environ . get ( "MASTER_ADDR" , "127.0.0.1" )
107- os .environ ["MASTER_PORT" ] = os . environ . get ( "MASTER_PORT" , str ( find_free_network_port () ))
104+ os .environ ["MASTER_ADDR" ] = self . cluster_environment . master_address ( )
105+ os .environ ["MASTER_PORT" ] = str ( self . cluster_environment . master_port ( ))
108106
109107 # allow the user to pass the node rank
110- node_rank = "0"
111- node_rank = os .environ .get ("NODE_RANK" , node_rank )
112- node_rank = os .environ .get ("GROUP_RANK" , node_rank )
113- os .environ ["NODE_RANK" ] = node_rank
114- os .environ ["LOCAL_RANK" ] = "0"
108+ os .environ ["NODE_RANK" ] = str (self .cluster_environment .node_rank ())
109+ os .environ ["LOCAL_RANK" ] = str (self .cluster_environment .local_rank ())
115110
116111 # when user is using hydra find the absolute path
117112 path_lib = os .path .abspath if not _HYDRA_AVAILABLE else to_absolute_path
@@ -205,7 +200,6 @@ def determine_ddp_device_ids(self):
205200 return [self .root_device .index ]
206201
207202 def init_ddp_connection (self , global_rank : int , world_size : int ) -> None :
208- # TODO: From where to get cluster environment?
209203 os .environ ["MASTER_ADDR" ] = str (self .cluster_environment .master_address ())
210204 os .environ ["MASTER_PORT" ] = str (self .cluster_environment .master_port ())
211205 os .environ ["WORLD_SIZE" ] = str (self .cluster_environment .world_size ())
0 commit comments