diff --git a/src/sagemaker_training/environment.py b/src/sagemaker_training/environment.py
index 3482151c..899e9941 100644
--- a/src/sagemaker_training/environment.py
+++ b/src/sagemaker_training/environment.py
@@ -562,6 +562,9 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
         self._master_hostname = list(hosts)[0]
         self._is_master = current_host == self._master_hostname
 
+        mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS)
+        self._is_modelparallel_enabled = bool(mp_parameters and mp_parameters != "{}")
+
     @property
     def model_dir(self):  # type: () -> str
         """The directory where models should be saved.
@@ -909,6 +912,15 @@ def framework_module(self):  # type: () -> str
         """
         return self._framework_module
 
+    @property
+    def is_modelparallel_enabled(self):  # type: () -> bool
+        """Whether SM ModelParallel is enabled.
+
+        Returns:
+            bool: True if SM ModelParallel is enabled.
+        """
+        return self._is_modelparallel_enabled
+
 
 def write_env_vars(env_vars=None):  # type: (dict) -> None
     """Write the dictionary env_vars in the system, as environment variables.
diff --git a/src/sagemaker_training/mpi.py b/src/sagemaker_training/mpi.py
index 7c5fd741..93a3a130 100644
--- a/src/sagemaker_training/mpi.py
+++ b/src/sagemaker_training/mpi.py
@@ -13,7 +13,7 @@
 """This module contains functionality related to distributed training using
 MPI (Message Passing Interface)."""
 import argparse
-import inspect
+from inspect import getfile, isclass
 import logging
 import os
 import subprocess
@@ -23,11 +23,31 @@
 import psutil
 
 import gethostname
-from sagemaker_training import logging_config, process, timeout
+from sagemaker_training import environment, errors, logging_config, process, timeout
 
 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)
 
+# start from an empty list so the torch block below can extend it safely
+# even if the backend import fails
+exception_classes = []
+try:
+    from smdistributed.modelparallel.backend import exceptions
+
+    # names of the exception classes SMMP wants the training toolkit to catch and log
+    exception_classes += [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
+except ImportError:
+    logger.info("No exception classes found in smdistributed.modelparallel.backend")
+
+try:
+    from smdistributed.modelparallel.torch import exceptions as torch_exceptions
+
+    # names of the torch exception classes SMMP wants the training toolkit to catch and log
+    exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))]
+except ImportError:
+    logger.info("No torch exception classes found in smdistributed.modelparallel.torch")
+
+if not exception_classes:
+    exception_classes = [errors.ExecuteUserScriptError]
+
 
 class WorkerRunner(process.ProcessRunner):
     """Runner responsible for preparing MPI distributed training and waiting for MPI
@@ -235,12 +255,16 @@ def _create_command(self):
             "-x",
             "PATH",
             "-x",
-            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
+            "LD_PRELOAD=%s" % getfile(gethostname),
         ]
 
         command.extend(additional_options)
 
-        for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+        for credential in [
+            "AWS_ACCESS_KEY_ID",
+            "AWS_SECRET_ACCESS_KEY",
+            "AWS_SESSION_TOKEN",
+        ]:
             if credential in os.environ:
                 command.extend(["-x", credential])
 
@@ -256,6 +280,49 @@ def _python_command(self):
         """
         return super(MasterRunner, self)._python_command() + ["-m", "mpi4py"]
 
+    def run(self, wait=True, capture_error=False):
+        """Run the process.
+
+        Args:
+            wait (bool): A boolean indicating whether to wait and check for errors.
+                Defaults to True.
+            capture_error (bool): A boolean indicating whether to direct stderr to a stream
+                that can later be read. Defaults to False.
+
+        Returns:
+            process (subprocess.Popen): The spawned process.
+        """
+        self._setup()
+
+        cmd = self._create_command()
+
+        logging_config.log_script_invocation(cmd, self._env_vars)
+
+        training_env = environment.Environment()
+        if wait:
+            process_spawned = process.check_error(
+                cmd,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
+                self._processes_per_host,
+                capture_error=capture_error,
+                cwd=environment.code_dir,
+            )
+        else:
+            _, _, process_spawned = process.create(
+                cmd,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
+                self._processes_per_host,
+                capture_error=capture_error,
+                cwd=environment.code_dir,
+            )
+
+        self._tear_down()
+        return process_spawned
+
 
 _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
 SSH daemon not found, please install SSH to allow MPI to communicate between
 different nodes in the cluster.
diff --git a/src/sagemaker_training/params.py b/src/sagemaker_training/params.py
index ab3c5eb6..54808c00 100644
--- a/src/sagemaker_training/params.py
+++ b/src/sagemaker_training/params.py
@@ -62,3 +62,4 @@
 SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = (
     "sagemaker_distributed_dataparallel_custom_mpi_options"
 )  # type: str
+SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"
diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py
index a3ea2431..492d3d3a 100644
--- a/src/sagemaker_training/process.py
+++ b/src/sagemaker_training/process.py
@@ -17,6 +17,7 @@
 
 import asyncio
 from asyncio.subprocess import PIPE
+from inspect import isclass
 import os
 import re
 import subprocess
@@ -38,7 +39,24 @@
 _DEFAULT_BUF_SIZE = 1024 * 64
 
 
-async def watch(stream, proc_per_host):
+def process_error_classes(error_classes):
+    """Normalize error classes into a list of class-name strings.
+
+    Args:
+        error_classes: A class, a string, a list of either, or None.
+
+    Returns:
+        list: The error class names as strings.
+    """
+    if not error_classes:
+        return []
+    if not isinstance(error_classes, list):
+        error_classes = [error_classes]
+    return [error.__name__ if isclass(error) else error for error in error_classes]
+
+
+async def watch(stream, proc_per_host, error_classes=None):
     """Process the stdout and stderr streams on the fly.
     Decode the output lines
     Remove new line characters (if any)
@@ -48,10 +66,12 @@ async def watch(stream, proc_per_host):
 
     Args:
        stream: asyncio subprocess PIPE
       proc_per_host (int): Number of processes per host
+       error_classes (list): List of exception classes to watch and raise
 
    Returns:
       output: Filtered stderr
    """
+   error_classes = process_error_classes(error_classes)
    output = []
    buf_size = _DEFAULT_BUF_SIZE
    start = False
@@ -82,20 +102,31 @@
                line,
            )
            print(line)
-            # log only if necessary
+            # log only if necessary; strip node and rank ids so identical messages
+            # de-duplicate (errors may arrive on either stdout or stderr, and both
+            # streams pass through this watcher)
            err_line = re.sub(r"\[(\d),(\d)\]", "", err_line)
+
            if start:
-                if line not in output:
-                    output.append(err_line)
+                err_line = err_line.strip(" :\n") + "\n"
+                if err_line not in output:
+                    output.append(err_line)
            else:
-                if any(err in err_line for err in _PYTHON_ERRORS_):
+                # error_classes was normalized to a list of names above, so it
+                # can be concatenated with _PYTHON_ERRORS_ directly
+                if any(err in err_line for err in _PYTHON_ERRORS_ + error_classes):
+                    # start logging the error message once a target exception is found
                    start = True
-                    output.append(err_line + "\n")
+                    output.append(err_line.strip(" :\n") + "\n")
 
    return " ".join(output)
 
 
-async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
+async def run_async(cmd, processes_per_host, env, cwd, stderr, error_classes=None, **kwargs):
    """Method responsible for launching an asyncio subprocess shell.
    Uses asyncio.gather to collect the processed stdout and stderr.
 
@@ -105,6 +136,7 @@
        env: os.environ
        cwd (str): The location from which to run the command (default: None).
            If None, this defaults to the ``code_dir`` of the environment.
+        error_classes (list): List of exception classes to watch and raise
        **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.
 
    Returns:
        asyncio.subprocess.Process: The asyncio process for the given command.
 
    Raises:
-        error_class: If there is an exception raised when creating the process.
+        ExecuteUserScriptError: If there is an exception raised when creating the process.
    """
    cmd = " ".join(cmd)
    proc = await asyncio.create_subprocess_shell(
        cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
    )
 
    output = await asyncio.gather(
-        watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
+        watch(proc.stdout, processes_per_host, error_classes=error_classes),
+        watch(proc.stderr, processes_per_host, error_classes=error_classes),
    )
    logger.info("Waiting for the process to finish and give a return code.")
    return_code = await proc.wait()
@@ -129,12 +162,20 @@
    return return_code, output, proc
 
 
-def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_error=False, **kwargs):
+def create(
+    cmd,
+    error_classes,
+    processes_per_host,
+    cwd=None,
+    env=None,
+    capture_error=False,
+    **kwargs,
+):
    """Spawn a process with asyncio for the given command.
 
    Args:
        cmd (list): The command to be run.
-        error_class (cls): The class to use when raising an exception.
+        error_classes (list): List of exception classes to watch and raise.
        cwd (str): The location from which to run the command (default: None).
            If None, this defaults to the ``code_dir`` of the environment.
        env: os.environ
@@ -146,7 +187,7 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err
        asyncio.subprocess.Process: The asyncio process for the given command.
 
    Raises:
-        error_class: If there is an exception raised when creating the process.
+        ExecuteUserScriptError: If there is an exception raised when creating the process.
    """
    try:
        stderr = PIPE if capture_error else None
@@ -157,20 +198,25 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err
                env=env or os.environ,
                cwd=cwd or environment.code_dir,
                stderr=stderr,
+                error_classes=error_classes,
                **kwargs,
            )
        )
        return rc, output, proc
    except Exception as e:  # pylint: disable=broad-except
-        six.reraise(error_class, error_class(e), sys.exc_info()[2])
+        six.reraise(
+            errors.ExecuteUserScriptError,
+            errors.ExecuteUserScriptError(e),
+            sys.exc_info()[2],
+        )
 
 
-def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=True, **kwargs):
+def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs):
    """Run a command, raising an exception if there is an error.
 
    Args:
        cmd ([str]): The command to be run.
-        error_class (cls): The class to use when raising an exception.
+        error_classes (list): List of exception classes to watch and raise.
        processes_per_host (int): Number of processes per host
        capture_error (bool): Whether or not to include stderr in
            the exception message (default: True). In either case,
@@ -181,32 +227,57 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
        subprocess.Popen: The process for the given command.
 
    Raises:
-        error_class: If there is an exception raised when creating the process.
+        ExecuteUserScriptError: If there is an exception raised when creating the process.
""" - + error_classes = process_error_classes(error_classes) if capture_error: return_code, output, process = create( cmd, - error_class, + error_classes, processes_per_host, env=os.environ, cwd=cwd or environment.code_dir, capture_error=True, **kwargs, ) - stderr = output[1] + stderr = " ".join(output) + # remove duplicate while preserve order + stderr = "\n".join(list(dict.fromkeys(stderr.split("\n")))).strip() else: stderr = None # remove extra quotes for subprocess.Popen cmd[-1] = cmd[-1].strip('"') process = subprocess.Popen( - cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs + cmd, + env=os.environ, + cwd=cwd or environment.code_dir, + stderr=stderr, + **kwargs, ) return_code = process.wait() if return_code: extra_info = None if return_code == 137: extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)" + + # throw internal error classes first + internal_errors = [err for err in dir(errors) if isclass(getattr(errors, err))] + error_class = next( + (name for name in error_classes if name in internal_errors), "ExecuteUserScriptError" + ) + error_class = getattr(errors, error_class) + + # only replace ExecuteUserScriptError with custom library errors + if stderr and error_class == errors.ExecuteUserScriptError: + # find the first target error in stderr + error_name = next((str(name) for name in error_classes if str(name) in stderr), False) + if error_name: + error_class = type( + error_name, + (errors._CalledProcessError,), # pylint: disable=protected-access + {}, + ) + raise error_class( cmd=" ".join(cmd) if isinstance(cmd, list) else cmd, return_code=return_code, @@ -259,7 +330,11 @@ def _create_command(self): six.moves.shlex_quote(arg) # pylint: disable=too-many-function-args for arg in self._args ] - return ["/bin/sh", "-c", '"./%s %s"' % (self._user_entry_point, " ".join(args))] + return [ + "/bin/sh", + "-c", + '"./%s %s"' % (self._user_entry_point, " ".join(args)), + ] def _python_command(self): # pylint: disable=no-self-use return [python_executable()] diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 69498254..de5d08d9 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """Contains functionality related to SM Distributed Data Parallel Training.""" import argparse -import inspect +from inspect import getfile, isclass import json import logging import os @@ -24,9 +24,19 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout + logger = logging_config.get_logger() logging.getLogger("paramiko").setLevel(logging.INFO) +try: + from smdistributed.dataparallel import exceptions + + # list of exceptions SMDDP wants training toolkit to catch and log + exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] +except ImportError: + logger.info("No exception classes found in smdistributed.dataparallel") + exception_classes = [errors.ExecuteUserScriptError] + class SMDataParallelRunner(process.ProcessRunner): """Prepare SMDataParallel-based distributed training. 
@@ -163,7 +173,7 @@ def _get_mpirun_command(
            "-x",
            "RDMAV_FORK_SAFE=1",
            "-x",
-            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
+            "LD_PRELOAD=%s" % getfile(gethostname),
        ]
 
        mpirun_command.extend(additional_options)
@@ -267,7 +277,7 @@ def run(self, wait=True, capture_error=False):
        if wait:
            process_spawned = process.check_error(
                cmd,
-                errors.ExecuteUserScriptError,
+                exception_classes,
                self._processes_per_host,
                capture_error=capture_error,
                cwd=environment.code_dir,
@@ -275,7 +285,7 @@ def run(self, wait=True, capture_error=False):
        else:
            process_spawned = process.create(
                cmd,
-                errors.ExecuteUserScriptError,
+                exception_classes,
                self._processes_per_host,
                capture_error=capture_error,
                cwd=environment.code_dir,
diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py
index 4b7c4f2d..4c6ca50a 100644
--- a/test/unit/test_environment.py
+++ b/test/unit/test_environment.py
@@ -212,6 +212,7 @@ def test_env_mapping_properties(training_env):
        "output_intermediate_dir",
        "is_master",
        "master_hostname",
+        "is_modelparallel_enabled",
    }
 
 
diff --git a/test/unit/test_process.py b/test/unit/test_process.py
index 3f0c10d2..86b9095e 100644
--- a/test/unit/test_process.py
+++ b/test/unit/test_process.py
@@ -105,7 +105,7 @@ async def test_watch(event_loop, capsys):
    expected_stream += (
        "[1,mpirank:0,algo-1]:FileNotFoundError: [Errno 2] No such file or directory\n"
    )
-    expected_errmsg = ":FileNotFoundError: [Errno 2] No such file or directory\n"
+    expected_errmsg = "FileNotFoundError: [Errno 2] No such file or directory\n"
 
    stream = asyncio.StreamReader()
    stream.feed_data(b"[1,10]:This is stdout\n")
@@ -119,6 +119,53 @@ async def test_watch(event_loop, capsys):
    assert output == expected_errmsg
 
 
+@pytest.mark.asyncio
+async def test_watch_custom_error(event_loop, capsys):
+    num_processes_per_host = 8
+    expected_stream = "[1,mpirank:10,algo-2]:This is stdout\n"
+    expected_stream += "[1,mpirank:10,algo-2]:This is stderr\n"
+    expected_stream += "[1,mpirank:0,algo-1]:SMDDPNCCLError: unhandled cuda error\n"
+    expected_errmsg = "SMDDPNCCLError: unhandled cuda error\n"
+
+    stream = asyncio.StreamReader()
+    stream.feed_data(b"[1,10]:This is stdout\n")
+    stream.feed_data(b"[1,10]:This is stderr\n")
+    stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error")
+    stream.feed_eof()
+
+    error_classes = ["SMDDPNCCLError"]
+    output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
+    captured_stream = capsys.readouterr()
+    assert captured_stream.out == expected_stream
+    assert output == expected_errmsg
+
+    # test errors piped to stdout
+    stream = asyncio.StreamReader()
+    stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error")
+    stream.feed_eof()
+
+    error_classes = ["SMDDPNCCLError"]
+    output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
+    assert output == expected_errmsg
+
+    # test a single item instead of a list
+    stream = asyncio.StreamReader()
+    stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error")
+    stream.feed_eof()
+    error_classes = "SMDDPNCCLError"
+    output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
+    assert output == expected_errmsg
+
+    # test an internal (toolkit-defined) error
+    expected_errmsg = "ImportModuleError: module does not exist\n"
+    stream = asyncio.StreamReader()
+    stream.feed_data(b"[1,0]:ImportModuleError: module does not exist")
+    stream.feed_eof()
+    error_classes = [errors.ImportModuleError]
+    output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
+    assert output == expected_errmsg
+
+
 @patch("asyncio.run", AsyncMock(side_effect=ValueError("FAIL")))
 def test_create_error():
    with pytest.raises(errors.ExecuteUserScriptError):