
Commit 9575835

mark todo exceptions (#5320)
* mark todo exceptions * . * . * . * . * . * . * . * . * try * .
1 parent af833f6 commit 9575835

File tree

20 files changed, +64 -49 lines

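The commit applies two recurring changes: remaining broad `except Exception:` blocks gain a `# todo: specify the possible exception` marker, and environment-variable lookups that used try/except as control flow are rewritten with `os.environ.get`. A minimal before/after sketch of the second pattern, lifted from the SLURM node lookup in the diffs below (runnable on its own, illustrative only):

import os

# Before: a missing variable was handled by catching everything.
try:
    root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
except Exception:
    root_node = "127.0.0.1"

# After: an unset variable is an expected case, not an error.
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if slurm_nodelist:
    root_node = slurm_nodelist.split(" ")[0]
else:
    root_node = "127.0.0.1"

print(root_node)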

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ def _call_children_scripts(self):
         command = sys.argv
         try:
             full_path = path_lib(command[0])
+        # todo: specify the possible exception
         except Exception:
             full_path = abspath(command[0])
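The added comment does not change behaviour; like the other todo markers in this commit, it flags a blanket handler that should eventually name the concrete exceptions it expects. A generic sketch of that narrowing, using int() as a stand-in (the function names and error types are illustrative, not a claim about path_lib):

def parse_flag(value: str) -> int:
    # Broad handler, the style the todo markers flag:
    try:
        return int(value)
    except Exception:  # also swallows unrelated bugs (NameError, AttributeError, ...)
        return 0

def parse_flag_narrowed(value: str) -> int:
    # Narrowed handler, the direction the todos point towards:
    try:
        return int(value)
    except (TypeError, ValueError):  # the failures int() can actually raise here
        return 0

print(parse_flag_narrowed("12"), parse_flag_narrowed("oops"))  # 12 0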

pytorch_lightning/cluster_environments/slurm_environment.py

Lines changed: 9 additions & 10 deletions
@@ -25,9 +25,10 @@ def __init__(self):

     def master_address(self):
         # figure out the root node addr
-        try:
-            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
-        except Exception:
+        slurm_nodelist = os.environ.get("SLURM_NODELIST")
+        if slurm_nodelist:
+            root_node = slurm_nodelist.split(" ")[0]
+        else:
             root_node = "127.0.0.1"

         root_node = self._resolve_root_node_address(root_node)
@@ -40,24 +41,22 @@ def master_port(self):
         # SLURM JOB = PORT number
         # -----------------------
         # this way every process knows what port to use
-        try:
+        default_port = os.environ.get("SLURM_JOB_ID")
+        if default_port:
             # use the last 4 numbers in the job id as the id
-            default_port = os.environ["SLURM_JOB_ID"]
             default_port = default_port[-4:]
-
             # all ports should be in the 10k+ range
             default_port = int(default_port) + 15000
-
-        except Exception:
+        else:
             default_port = 12910

         # -----------------------
         # PORT NUMBER = MASTER_PORT
         # -----------------------
         # in case the user passed it in
-        try:
+        if "MASTER_PORT" in os.environ:
             default_port = os.environ["MASTER_PORT"]
-        except Exception:
+        else:
             os.environ["MASTER_PORT"] = str(default_port)

         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
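For reference, the port derivation above reads as a small pure function; a minimal sketch with a helper name of my own choosing (an explicit MASTER_PORT, when set by the user, still overrides the result, as the diff keeps that branch):

import os

def default_master_port() -> int:
    # Hypothetical helper restating the logic above, for illustration only.
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id:
        # last 4 digits of the job id, shifted into the 10k+ range
        return int(job_id[-4:]) + 15000
    return 12910

os.environ["SLURM_JOB_ID"] = "1234567"
print(default_master_port())  # 4567 + 15000 = 19567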

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 0 deletions
@@ -1701,6 +1701,7 @@ def __get_hparams_assignment_variable(self):
                 line = re.sub(r"\s+", "", line, flags=re.UNICODE)
                 if ".hparams=" in line:
                     return line.split("=")[1]
+        # todo: specify the possible exception
         except Exception:
             return "hparams"

pytorch_lightning/core/step_result.py

Lines changed: 1 addition & 0 deletions
@@ -527,6 +527,7 @@ def reduce_on_epoch_end(cls, outputs):
                         result[k] = torch.tensor(result[k]).float()
                     try:
                         reduced_val = weighted_mean(result[k], batch_sizes)
+                    # todo: specify the expected Exceptions to come
                     except Exception:
                         reduced_val = torch.mean(result[k])
                     else:

pytorch_lightning/loggers/base.py

Lines changed: 1 addition & 0 deletions
@@ -200,6 +200,7 @@ def _sanitize_callable(val):
                     if isinstance(_val, Callable):
                         return val.__name__
                     return _val
+                # todo: specify the possible exception
                 except Exception:
                     return getattr(val, "__name__", None)
             return val

pytorch_lightning/loggers/tensorboard.py

Lines changed: 3 additions & 2 deletions
@@ -187,9 +187,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
             else:
                 try:
                     self.experiment.add_scalar(k, v, step)
-                except Exception as e:
+                # todo: specify the possible exception
+                except Exception as ex:
                     m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.'
-                    type(e)(e.message + m)
+                    type(ex)(ex.message + m)

     @rank_zero_only
     def log_graph(self, model: LightningModule, input_array=None):
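A side note on the handler itself: Python 3 exceptions have no .message attribute, and the constructed exception is never raised, so the extra hint is effectively lost. A hedged sketch of an explicit re-raise, not what this commit does, with assumed exception types per the todo:

def log_scalar_with_hint(experiment, name, value, step):
    # Sketch only: re-raise with the added hint and read str(ex) instead of the
    # Python 2-only ex.message; TypeError/ValueError are assumed candidates.
    try:
        experiment.add_scalar(name, value, step)
    except (TypeError, ValueError) as ex:
        hint = f"\n you tried to log {value} which is not currently supported. Try a dict or a scalar/tensor."
        raise type(ex)(str(ex) + hint) from ex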

pytorch_lightning/overrides/data_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ def _worker(i, module, input, kwargs, device=None):

             with lock:
                 results[i] = output
+        # todo: specify the possible exception
         except Exception as ex:
             with lock:
                 results[i] = ex

pytorch_lightning/trainer/connectors/slurm_connector.py

Lines changed: 14 additions & 17 deletions
@@ -37,18 +37,15 @@ def configure_slurm_ddp(self, num_gpu_nodes):
             job_name = os.environ['SLURM_JOB_NAME']
             if job_name == 'bash':
                 self.trainer.is_slurm_managing_tasks = False
-
+        # todo: specify the possible exception
         except Exception:
             # likely not on slurm, so set the slurm managed flag to false
             self.trainer.is_slurm_managing_tasks = False

         # used for tests only, set this flag to simulate slurm managing a task
-        try:
-            should_fake = int(os.environ['FAKE_SLURM_MANAGING_TASKS'])
-            if should_fake:
-                self.trainer.is_slurm_managing_tasks = True
-        except Exception:
-            pass
+        should_fake = os.environ.get('FAKE_SLURM_MANAGING_TASKS')
+        if should_fake and int(should_fake):
+            self.trainer.is_slurm_managing_tasks = True

         # notify user the that slurm is managing tasks
         if self.trainer.is_slurm_managing_tasks:
@@ -74,6 +71,7 @@ def register_slurm_signal_handlers(self):
             job_name = os.environ['SLURM_JOB_NAME']
             if job_name != 'bash':
                 on_slurm = True
+        # todo: specify the possible exception
         except Exception:
             pass
@@ -120,28 +118,27 @@ def connect_ddp(self, global_rank: int, world_size: int) -> None:
         """
         # use slurm job id for the port number
         # guarantees unique ports across jobs from same grid search
-        try:
+        default_port = os.environ.get("SLURM_JOB_ID")
+        if default_port:
             # use the last 4 numbers in the job id as the id
-            default_port = os.environ["SLURM_JOB_ID"]
             default_port = default_port[-4:]
-
             # all ports should be in the 10k+ range
             default_port = int(default_port) + 15000
-
-        except Exception:
+        else:
             default_port = 12910

         # if user gave a port number, use that one instead
-        try:
+        if "MASTER_PORT" in os.environ:
             default_port = os.environ["MASTER_PORT"]
-        except Exception:
+        else:
             os.environ["MASTER_PORT"] = str(default_port)
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

         # figure out the root node addr
-        try:
-            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
-        except Exception:
+        root_node = os.environ.get("SLURM_NODELIST")
+        if root_node:
+            root_node = root_node.split(" ")[0]
+        else:
             root_node = "127.0.0.1"

         root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node)
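One behavioural nuance of the FAKE_SLURM_MANAGING_TASKS rewrite: the old try/except silently ignored a non-numeric value, while the new code only guards against the variable being unset or empty. A small standalone illustration:

import os

def fake_slurm_flag_set() -> bool:
    should_fake = os.environ.get("FAKE_SLURM_MANAGING_TASKS")
    return bool(should_fake and int(should_fake))

os.environ["FAKE_SLURM_MANAGING_TASKS"] = "1"
print(fake_slurm_flag_set())  # True

os.environ["FAKE_SLURM_MANAGING_TASKS"] = ""
print(fake_slurm_flag_set())  # False, the empty string short-circuits before int()

# A value such as "yes" now raises ValueError from int(), where the removed
# except Exception: pass used to swallow it.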

pytorch_lightning/trainer/logging.py

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,7 @@ def process_dict_result(self, output, train=False):
                 progress_output = self.reduce_distributed_output(progress_output, num_gpus)

             progress_bar_metrics = progress_output
+        # todo: specify the possible exception
         except Exception:
             progress_bar_metrics = {}

@@ -128,6 +129,7 @@ def process_dict_result(self, output, train=False):
                 log_output = self.reduce_distributed_output(log_output, num_gpus)

             log_metrics = log_output
+        # todo: specify the possible exception
         except Exception:
             log_metrics = {}

@@ -140,6 +142,7 @@ def process_dict_result(self, output, train=False):
         if train:
             try:
                 loss = output['loss']
+            # todo: specify the possible exception
             except Exception as exp:
                 if isinstance(output, torch.Tensor):
                     loss = output

pytorch_lightning/trainer/properties.py

Lines changed: 8 additions & 8 deletions
@@ -118,16 +118,16 @@ def is_global_zero(self) -> bool:

     @property
     def slurm_job_id(self) -> Optional[int]:
-        try:
-            job_id = os.environ['SLURM_JOB_ID']
-            job_id = int(job_id)
-
-            # in interactive mode, don't make logs use the same job id
-            in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
-            if in_slurm_interactive_mode:
+        job_id = os.environ.get('SLURM_JOB_ID')
+        if job_id:
+            try:
+                job_id = int(job_id)
+            except ValueError:
                 job_id = None

-        except Exception:
+        # in interactive mode, don't make logs use the same job id
+        in_slurm_interactive_mode = os.environ.get('SLURM_JOB_NAME') == 'bash'
+        if in_slurm_interactive_mode:
             job_id = None
         return job_id
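properties.py is the one spot where the commit already names the exception: converting a malformed SLURM_JOB_ID string can only fail with ValueError, so the handler is narrowed to exactly that. A standalone restatement with a quick check (an illustrative function, not the Trainer property itself):

import os
from typing import Optional

def slurm_job_id_from_env() -> Optional[int]:
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id:
        try:
            job_id = int(job_id)
        except ValueError:  # e.g. a non-numeric job id
            job_id = None

    # in interactive mode, don't make logs use the same job id
    if os.environ.get("SLURM_JOB_NAME") == "bash":
        job_id = None
    return job_id

os.environ["SLURM_JOB_ID"] = "123456"
print(slurm_job_id_from_env())  # 123456
os.environ["SLURM_JOB_NAME"] = "bash"
print(slurm_job_id_from_env())  # None, interactive session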
