1313"""This module contains functionality related to distributed training using
1414MPI (Message Passing Interface)."""
1515import argparse
16- import inspect
1716import logging
1817import os
1918import subprocess
2423
2524import gethostname
 from sagemaker_training import environment, errors, logging_config, process, timeout
-from inspect import isclass
+from inspect import getfile, isclass

 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)

 try:
     from smdistributed.modelparallel.backend import exceptions
+
     # list of exceptions SMMP wants training toolkit to catch and log
     exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
 except ImportError as e:
     logger.info("No exception classes found in smdistributed.modelparallel")
     exception_classes = []
 try:
     from smdistributed.modelparallel.torch import exceptions as torch_exceptions
+
     # list of torch exceptions SMMP wants training toolkit to catch and log
-    exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))]
+    exception_classes += [
+        x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))
+    ]
 except ImportError as e:
     logger.info("No torch exception classes found in smdistributed.modelparallel")
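 # exception_classes is passed to process.check_error()/process.create() in
 # MasterRunner.run() below so model-parallel failures raised by the user
 # script are caught and logged by the toolkit.
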
@@ -49,7 +52,9 @@ class WorkerRunner(process.ProcessRunner):
     master execution to finish.
     """

-    def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_hostname):
+    def __init__(
+        self, user_entry_point, args, env_vars, processes_per_host, master_hostname
+    ):
         """Initialize a WorkerRunner, which is responsible for preparing distributed
         training with MPI and waiting for MPI master execution to finish.

@@ -59,7 +64,9 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_
             env_vars (dict(str,str)): A dictionary of environment variables.
             master_hostname (str): The master hostname.
         """
-        super(WorkerRunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
+        super(WorkerRunner, self).__init__(
+            user_entry_point, args, env_vars, processes_per_host
+        )
         self._master_hostname = str(master_hostname)
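         # run() waits on this hostname via _wait_master_to_start() before this
         # worker starts its own SSH daemon.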

     def run(
@@ -77,7 +84,9 @@ def run(
         self._wait_master_to_start()
         logger.info("MPI Master online, creating SSH daemon.")

-        logger.info("Writing environment variables to /etc/environment for the MPI process.")
+        logger.info(
+            "Writing environment variables to /etc/environment for the MPI process."
+        )
         _write_env_vars_to_file()

         _start_sshd_daemon()
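         # sshd reads /etc/environment for new sessions, so the MPI processes the
         # master later launches over SSH inherit the variables written above.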
@@ -114,7 +123,9 @@ def _wait_orted_process_to_finish(): # type: () -> None
 def _orted_process():  # pylint: disable=inconsistent-return-statements
     """Wait a maximum of 5 minutes for orted process to start."""
     for _ in range(5 * 60):
-        procs = [p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted"]
+        procs = [
+            p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted"
+        ]
         if procs:
             logger.info("Process[es]: %s", procs)
             return procs
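         # orted is Open MPI's per-node daemon; once it appears, mpirun on the
         # master has reached this worker and _wait_orted_process_to_finish() can
         # block until it exits.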
@@ -158,7 +169,9 @@ def __init__(
             3600 seconds (i.e. 1 hour).
         """

-        super(MasterRunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
+        super(MasterRunner, self).__init__(
+            user_entry_point, args, env_vars, processes_per_host
+        )

         self._master_hostname = master_hostname
         self._hosts = hosts
@@ -194,10 +207,14 @@ def _create_command(self):
         if self._processes_per_host == 1:
             host_list = self._hosts
         else:
-            host_list = ["%s:%s" % (host, self._processes_per_host) for host in self._hosts]
+            host_list = [
+                "%s:%s" % (host, self._processes_per_host) for host in self._hosts
+            ]
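         # "host:slots" is mpirun's host-list syntax for scheduling
         # processes_per_host ranks on each node.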

         msg = "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s"
-        logger.info(msg, self._hosts, host_list, self._processes_per_host, num_processes)
+        logger.info(
+            msg, self._hosts, host_list, self._processes_per_host, num_processes
+        )

         overridden_known_options, additional_options = _parse_custom_mpi_options(
             self._custom_mpi_options
@@ -250,12 +267,16 @@ def _create_command(self):
250267 "-x" ,
251268 "PATH" ,
252269 "-x" ,
253- "LD_PRELOAD=%s" % inspect . getfile (gethostname ),
270+ "LD_PRELOAD=%s" % getfile (gethostname ),
254271 ]
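         # LD_PRELOAD points at the toolkit's compiled gethostname shim so that
         # gethostname() inside MPI processes resolves the container's host alias.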

         command.extend(additional_options)

-        for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+        for credential in [
+            "AWS_ACCESS_KEY_ID",
+            "AWS_SECRET_ACCESS_KEY",
+            "AWS_SESSION_TOKEN",
+        ]:
             if credential in os.environ:
                 command.extend(["-x", credential])
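                 # mpirun -x exports the named variable to every rank; checking
                 # os.environ first avoids exporting credentials that are not set.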

@@ -293,15 +314,19 @@ def run(self, wait=True, capture_error=False):
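         # With wait=True the toolkit blocks on the MPI command via
         # process.check_error() and raises on failure (using the SMMP exception
         # classes when model parallelism is enabled); with wait=False the spawned
         # process is returned without blocking.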
         if wait:
             process_spawned = process.check_error(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
             )
         else:
             _, _, process_spawned = process.create(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
             )
@@ -311,7 +336,6 @@ def run(self, wait=True, capture_error=False):
         return process_spawned


-
 _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
 SSH daemon not found, please install SSH to allow MPI to communicate between nodes in the cluster.
