@@ -37,18 +37,15 @@ def configure_slurm_ddp(self, num_gpu_nodes):
3737 job_name = os .environ ['SLURM_JOB_NAME' ]
3838 if job_name == 'bash' :
3939 self .trainer .is_slurm_managing_tasks = False
40-
40+ # todo: specify the possible exception
4141 except Exception :
4242 # likely not on slurm, so set the slurm managed flag to false
4343 self .trainer .is_slurm_managing_tasks = False
4444
4545 # used for tests only, set this flag to simulate slurm managing a task
46- try :
47- should_fake = int (os .environ ['FAKE_SLURM_MANAGING_TASKS' ])
48- if should_fake :
49- self .trainer .is_slurm_managing_tasks = True
50- except Exception :
51- pass
46+ should_fake = os .environ .get ('FAKE_SLURM_MANAGING_TASKS' )
47+ if should_fake and int (should_fake ):
48+ self .trainer .is_slurm_managing_tasks = True
5249
5350 # notify user that slurm is managing tasks
5451 if self .trainer .is_slurm_managing_tasks :
@@ -74,6 +71,7 @@ def register_slurm_signal_handlers(self):
7471 job_name = os .environ ['SLURM_JOB_NAME' ]
7572 if job_name != 'bash' :
7673 on_slurm = True
74+ # todo: specify the possible exception
7775 except Exception :
7876 pass
7977
@@ -120,28 +118,27 @@ def connect_ddp(self, global_rank: int, world_size: int) -> None:
120118 """
121119 # use slurm job id for the port number
122120 # guarantees unique ports across jobs from same grid search
123- try :
121+ default_port = os .environ .get ("SLURM_JOB_ID" )
122+ if default_port :
124123 # use the last 4 numbers in the job id as the id
125- default_port = os .environ ["SLURM_JOB_ID" ]
126124 default_port = default_port [- 4 :]
127-
128125 # all ports should be in the 10k+ range
129126 default_port = int (default_port ) + 15000
130-
131- except Exception :
127+ else :
132128 default_port = 12910
133129
134130 # if user gave a port number, use that one instead
135- try :
131+ if "MASTER_PORT" in os . environ :
136132 default_port = os .environ ["MASTER_PORT" ]
137- except Exception :
133+ else :
138134 os .environ ["MASTER_PORT" ] = str (default_port )
139135 log .debug (f"MASTER_PORT: { os .environ ['MASTER_PORT' ]} " )
140136
141137 # figure out the root node addr
142- try :
143- root_node = os .environ ["SLURM_NODELIST" ].split (" " )[0 ]
144- except Exception :
138+ root_node = os .environ .get ("SLURM_NODELIST" )
139+ if root_node :
140+ root_node = root_node .split (" " )[0 ]
141+ else :
145142 root_node = "127.0.0.1"
146143
147144 root_node = self .trainer .slurm_connector .resolve_root_node_address (root_node )
0 commit comments