Commit 3a1b51c

fix style of new commit
1 parent 650f30a commit 3a1b51c

3 files changed: +48 -17 lines changed

File tree:
  src/sagemaker_training/mpi.py
  src/sagemaker_training/params.py
  src/sagemaker_training/process.py

src/sagemaker_training/mpi.py

Lines changed: 39 additions & 15 deletions
@@ -13,7 +13,6 @@
 """This module contains functionality related to distributed training using
 MPI (Message Passing Interface)."""
 import argparse
-import inspect
 import logging
 import os
 import subprocess

@@ -24,22 +23,26 @@

 import gethostname
 from sagemaker_training import environment, errors, logging_config, process, timeout
-from inspect import isclass
+from inspect import getfile, isclass

 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)

 try:
     from smdistributed.modelparallel.backend import exceptions
+
     # list of exceptions SMMP wants training toolkit to catch and log
     exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
 except ImportError as e:
     logger.info("No exception classes found in smdistributed.modelparallel")
     exception_classes = []
 try:
     from smdistributed.modelparallel.torch import exceptions as torch_exceptions
+
     # list of torch exceptions SMMP wants training toolkit to catch and log
-    exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))]
+    exception_classes += [
+        x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))
+    ]
 except ImportError as e:
     logger.info("No torch exception classes found in smdistributed.modelparallel")
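The exception-name collection above relies only on dir() plus inspect.isclass, and it gathers class names (strings), not the class objects themselves. A minimal, hedged sketch of the same pattern, using the standard-library builtins module as a stand-in because smdistributed.modelparallel may not be installed locally:

import logging
from inspect import isclass

logger = logging.getLogger(__name__)

try:
    # stand-in for "from smdistributed.modelparallel.backend import exceptions"
    import builtins as exceptions

    # collect the *names* of every class the module exposes
    exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
except ImportError:
    logger.info("No exception classes found, falling back to an empty list")
    exception_classes = []

print(exception_classes[:3])  # e.g. ['ArithmeticError', 'AssertionError', 'AttributeError']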

@@ -49,7 +52,9 @@ class WorkerRunner(process.ProcessRunner):
     master execution to finish.
     """

-    def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_hostname):
+    def __init__(
+        self, user_entry_point, args, env_vars, processes_per_host, master_hostname
+    ):
         """Initialize a WorkerRunner, which is responsible for preparing distributed
         training with MPI and waiting for MPI master execution to finish.

@@ -59,7 +64,9 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_
             env_vars (dict(str,str)): A dictionary of environment variables.
             master_hostname (str): The master hostname.
         """
-        super(WorkerRunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
+        super(WorkerRunner, self).__init__(
+            user_entry_point, args, env_vars, processes_per_host
+        )
         self._master_hostname = str(master_hostname)

     def run(

@@ -77,7 +84,9 @@ def run(
         self._wait_master_to_start()
         logger.info("MPI Master online, creating SSH daemon.")

-        logger.info("Writing environment variables to /etc/environment for the MPI process.")
+        logger.info(
+            "Writing environment variables to /etc/environment for the MPI process."
+        )
         _write_env_vars_to_file()

         _start_sshd_daemon()

@@ -114,7 +123,9 @@ def _wait_orted_process_to_finish():  # type: () -> None
 def _orted_process():  # pylint: disable=inconsistent-return-statements
     """Wait a maximum of 5 minutes for orted process to start."""
     for _ in range(5 * 60):
-        procs = [p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted"]
+        procs = [
+            p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted"
+        ]
         if procs:
             logger.info("Process[es]: %s", procs)
             return procs
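The psutil call above pre-fetches the requested attributes into each process's .info dictionary, avoiding one extra system call per attribute access. A hedged, self-contained sketch of the same polling pattern (the function name and the one-second sleep are assumptions, not the toolkit's code):

import time

import psutil


def wait_for_process(name, timeout_seconds=300):
    """Poll once per second until a process with the given name appears."""
    for _ in range(timeout_seconds):
        procs = [
            p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == name
        ]
        if procs:
            return procs
        time.sleep(1)
    raise RuntimeError("process %s did not start within %s seconds" % (name, timeout_seconds))

Calling wait_for_process("orted") roughly mirrors what _orted_process does while the MPI master launches orted daemons on the worker nodes.
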
@@ -158,7 +169,9 @@ def __init__(
             3600 seconds (ie. 1 hour).
         """

-        super(MasterRunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
+        super(MasterRunner, self).__init__(
+            user_entry_point, args, env_vars, processes_per_host
+        )

         self._master_hostname = master_hostname
         self._hosts = hosts

@@ -194,10 +207,14 @@ def _create_command(self):
         if self._processes_per_host == 1:
             host_list = self._hosts
         else:
-            host_list = ["%s:%s" % (host, self._processes_per_host) for host in self._hosts]
+            host_list = [
+                "%s:%s" % (host, self._processes_per_host) for host in self._hosts
+            ]

         msg = "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s"
-        logger.info(msg, self._hosts, host_list, self._processes_per_host, num_processes)
+        logger.info(
+            msg, self._hosts, host_list, self._processes_per_host, num_processes
+        )

         overridden_known_options, additional_options = _parse_custom_mpi_options(
             self._custom_mpi_options
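For context, the host_list built above uses the host:slots notation that Open MPI understands. A small hedged example with made-up hostnames:

hosts = ["algo-1", "algo-2"]  # hypothetical hostnames
processes_per_host = 4

host_list = ["%s:%s" % (host, processes_per_host) for host in hosts]
print(host_list)            # ['algo-1:4', 'algo-2:4']
print(",".join(host_list))  # 'algo-1:4,algo-2:4', the comma-joined form typically passed to mpirun's host option
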
@@ -250,12 +267,16 @@ def _create_command(self):
             "-x",
             "PATH",
             "-x",
-            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
+            "LD_PRELOAD=%s" % getfile(gethostname),
         ]

         command.extend(additional_options)

-        for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+        for credential in [
+            "AWS_ACCESS_KEY_ID",
+            "AWS_SECRET_ACCESS_KEY",
+            "AWS_SESSION_TOKEN",
+        ]:
             if credential in os.environ:
                 command.extend(["-x", credential])
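The LD_PRELOAD change only swaps the fully qualified inspect.getfile for the directly imported getfile; both resolve a module object to the file that defines it (for a compiled extension such as gethostname, a shared object). A short hedged illustration using a standard-library module instead:

from inspect import getfile
import json

# getfile returns the path of the file a module was defined in
print(getfile(json))  # e.g. /usr/lib/python3.10/json/__init__.py
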

@@ -293,15 +314,19 @@ def run(self, wait=True, capture_error=False):
         if wait:
             process_spawned = process.check_error(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
             )
         else:
             _, _, process_spawned = process.create(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
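The wrapped ternary above is the substantive part of this hunk: when model parallelism is enabled, the SMMP exception names collected at import time are passed to process.check_error or process.create; otherwise the single errors.ExecuteUserScriptError type is used. A hedged stand-alone sketch of that selection (the flag and names below are placeholders):

is_modelparallel_enabled = False  # placeholder for training_env.is_modelparallel_enabled
exception_classes = ["SMPUnsupportedError"]  # placeholder for the names collected earlier

error_classes = (
    exception_classes
    if is_modelparallel_enabled
    else RuntimeError  # stands in for errors.ExecuteUserScriptError
)
print(error_classes)  # <class 'RuntimeError'> when model parallelism is off
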
@@ -311,7 +336,6 @@ def run(self, wait=True, capture_error=False):
         return process_spawned


-
 _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
 SSH daemon not found, please install SSH to allow MPI to communicate different nodes in cluster.

src/sagemaker_training/params.py

Lines changed: 1 addition & 1 deletion
@@ -59,4 +59,4 @@
 SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = (
     "sagemaker_distributed_dataparallel_custom_mpi_options"
 )  # type: str
-SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"
+SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"

(The removed and added lines are textually identical; the change most likely adds a trailing newline at the end of the file.)

src/sagemaker_training/process.py

Lines changed: 8 additions & 1 deletion
@@ -90,7 +90,14 @@ async def watch(stream, error_classes, proc_per_host):
             if err_line not in output:
                 output.append(err_line.strip(" :\n") + "\n")
         else:
-            if any(str(err) in err_line for err in (_PYTHON_ERRORS_ + error_classes if type(error_classes) == list else [error_classes])):
+            if any(
+                str(err) in err_line
+                for err in (
+                    _PYTHON_ERRORS_ + error_classes
+                    if isinstance(error_classes, list)
+                    else [error_classes]
+                )
+            ):
                 # start logging error message if target exceptions found
                 start = True
                 output.append(err_line.strip(" :\n") + "\n")
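Beyond wrapping the long condition, this hunk also swaps type(error_classes) == list for the more idiomatic isinstance(error_classes, list). A minimal hedged sketch of the matching logic, with _PYTHON_ERRORS_ replaced by a hypothetical stand-in list:

_PYTHON_ERRORS_ = ["Traceback", "SyntaxError", "NameError"]  # hypothetical stand-in


def matches_known_error(err_line, error_classes):
    """Return True if the line mentions any known Python error or toolkit error class."""
    candidates = (
        _PYTHON_ERRORS_ + error_classes
        if isinstance(error_classes, list)
        else [error_classes]
    )
    return any(str(err) in err_line for err in candidates)


print(matches_known_error("NameError: name 'x' is not defined", []))  # True
print(matches_known_error("epoch 3 loss 0.2", []))                    # False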
