From cfaff0ff0481da2f19716d142fbc1294313f9d8d Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 13 Apr 2022 11:24:59 -0700 Subject: [PATCH 1/7] log smddp exceptions --- src/sagemaker_training/process.py | 3 +- .../smdataparallel_exceptions.py | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 src/sagemaker_training/smdataparallel_exceptions.py diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index a3ea2431..f8480a1b 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -30,6 +30,7 @@ environment, errors, logging_config, + SMDDP_EXCEPTIONS ) logger = logging_config.get_logger() @@ -88,7 +89,7 @@ async def watch(stream, proc_per_host): if line not in output: output.append(err_line) else: - if any(err in err_line for err in _PYTHON_ERRORS_): + if any(err in err_line for err in (_PYTHON_ERRORS_ + SMDDP_EXCEPTIONS)): start = True output.append(err_line + "\n") diff --git a/src/sagemaker_training/smdataparallel_exceptions.py b/src/sagemaker_training/smdataparallel_exceptions.py new file mode 100644 index 00000000..46ddceed --- /dev/null +++ b/src/sagemaker_training/smdataparallel_exceptions.py @@ -0,0 +1,29 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+""" +Custom Exceptions to capture Herring errors in the Telemetry system +""" +SMDDP_EXCEPTIONS = [ + "SMDDPError", + "SMDDPValidationError", + "SMDDPConfigError", + "SMDDPInvalidArgumentError", + "SMDDPUnsupportedError", + "SMDDPRuntimeError", + "SMDDPInitializationError", + "SMDDPLogicError", + "SMDDPTimeoutError", + "SMDDPCUDAError", + "SMDDPNCCLError", + "SMDDPOOMError" +] \ No newline at end of file From 9dc55e082a6edeba142ae6feeaadfa6474f23039 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 15 Apr 2022 02:08:43 -0700 Subject: [PATCH 2/7] update exception class --- src/sagemaker_training/process.py | 45 ++++++++++++------- src/sagemaker_training/smdataparallel.py | 12 ++++- .../smdataparallel_exceptions.py | 29 ------------ 3 files changed, 38 insertions(+), 48 deletions(-) delete mode 100644 src/sagemaker_training/smdataparallel_exceptions.py diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index f8480a1b..1516ca5c 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -30,7 +30,6 @@ environment, errors, logging_config, - SMDDP_EXCEPTIONS ) logger = logging_config.get_logger() @@ -38,8 +37,7 @@ # Default limit of the stream is 2 ** 16 KB, we can increase it to 128KB in subproc call _DEFAULT_BUF_SIZE = 1024 * 64 - -async def watch(stream, proc_per_host): +async def watch(stream, error_classes, proc_per_host): """Process the stdout and stderr streams on the fly. 
Decode the output lines Remove new line characters (if any) @@ -48,6 +46,7 @@ async def watch(stream, proc_per_host): Args: stream: asyncio subprocess PIPE + error_classes (list): List of exception classes watch and raise proc_per_host (int): Number of processes per each host Returns: @@ -83,25 +82,30 @@ async def watch(stream, proc_per_host): line, ) print(line) - # log only if necessary + # log only if necessary, remove node and rank id for de-duplication err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) + # in case error piped to stdout + err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) + if start: if line not in output: output.append(err_line) else: - if any(err in err_line for err in (_PYTHON_ERRORS_ + SMDDP_EXCEPTIONS)): + if any(err in err_line for err in (_PYTHON_ERRORS_ + error_classes)): + # start logging error message if target exceptions found start = True output.append(err_line + "\n") return " ".join(output) -async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs): +async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **kwargs): """Method responsible for launching asyncio subprocess shell Use asyncio gather to collect processed stdout and stderr Args: cmd (list): The command to be run + error_classes (list): List of exception classes watch and raise processes_per_host (int): Number of processes per host env: os.environ cwd (str): The location from which to run the command (default: None). @@ -114,7 +118,7 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs): asyncio.subprocess.Process: The asyncio process for the given command. Raises: - error_class: If there is an exception raised when creating the process. + ExecuteUserScriptError: If there is an exception raised when creating the process. 
""" cmd = " ".join(cmd) proc = await asyncio.create_subprocess_shell( @@ -122,7 +126,7 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs): ) output = await asyncio.gather( - watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host) + watch(proc.stdout, error_classes, processes_per_host), watch(proc.stderr, error_classes, processes_per_host) ) logger.info("Waiting for the process to finish and give a return code.") return_code = await proc.wait() @@ -130,12 +134,12 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs): return return_code, output, proc -def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_error=False, **kwargs): +def create(cmd, error_classes, processes_per_host, cwd=None, env=None, capture_error=False, **kwargs): """Spawn a process with asyncio for the given command. Args: cmd (list): The command to be run. - error_class (cls): The class to use when raising an exception. + error_classes (list): List of exception classes watch and raise. cwd (str): The location from which to run the command (default: None). If None, this defaults to the ``code_dir`` of the environment. env: os.environ @@ -147,13 +151,14 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err asyncio.subprocess.Process: The asyncio process for the given command. Raises: - error_class: If there is an exception raised when creating the process. + ExecuteUserScriptError: If there is an exception raised when creating the process. 
""" try: stderr = PIPE if capture_error else None rc, output, proc = asyncio.run( run_async( cmd, + error_classes, processes_per_host, env=env or os.environ, cwd=cwd or environment.code_dir, @@ -163,15 +168,15 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err ) return rc, output, proc except Exception as e: # pylint: disable=broad-except - six.reraise(error_class, error_class(e), sys.exc_info()[2]) + six.reraise(errors.ExecuteUserScriptError, errors.ExecuteUserScriptError(e), sys.exc_info()[2]) -def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=True, **kwargs): +def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs): """Run a commmand, raising an exception if there is an error. Args: cmd ([str]): The command to be run. - error_class (cls): The class to use when raising an exception. + error_classes (list): List of exception classes watch and raise. processes_per_host (int): Number of processes per host capture_error (bool): Whether or not to include stderr in the exception message (default: True). In either case, @@ -182,20 +187,20 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr subprocess.Popen: The process for the given command. Raises: - error_class: If there is an exception raised when creating the process. + ExecuteUserScriptError: If there is an exception raised when creating the process. 
""" if capture_error: return_code, output, process = create( cmd, - error_class, + error_classes, processes_per_host, env=os.environ, cwd=cwd or environment.code_dir, capture_error=True, **kwargs, ) - stderr = output[1] + stderr = " ".join(output) else: stderr = None # remove extra quotes for subprocess.Popen @@ -208,6 +213,12 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr extra_info = None if return_code == 137: extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)" + error_class = errors.ExecuteUserScriptError + for error_name in error_classes: + if error_name in stderr: + error_class = type(error_class_str, (errors._CalledProcessError,), {}) + break + raise error_class( cmd=" ".join(cmd) if isinstance(cmd, list) else cmd, return_code=return_code, diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 69498254..eb862f85 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -23,6 +23,14 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout +from inspect import isclass +try: + from smdistributed.dataparallel import exceptions + # list of exceptions SMDDP wants training toolkit to catch and log + exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] +except ImportError as e: + logger.info("No exception classes found in smdistributed.dataparallel") + exception_classes = [] logger = logging_config.get_logger() logging.getLogger("paramiko").setLevel(logging.INFO) @@ -267,7 +275,7 @@ def run(self, wait=True, capture_error=False): if wait: process_spawned = process.check_error( cmd, - errors.ExecuteUserScriptError, + exception_classes, self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, @@ -275,7 +283,7 @@ def run(self, wait=True, capture_error=False): else: process_spawned = process.create( cmd, - 
errors.ExecuteUserScriptError, + exception_classes, self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, diff --git a/src/sagemaker_training/smdataparallel_exceptions.py b/src/sagemaker_training/smdataparallel_exceptions.py deleted file mode 100644 index 46ddceed..00000000 --- a/src/sagemaker_training/smdataparallel_exceptions.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -""" -Custom Exceptions to capture Herring errors in the Telemetry system -""" -SMDDP_EXCEPTIONS = [ - "SMDDPError", - "SMDDPValidationError", - "SMDDPConfigError", - "SMDDPInvalidArgumentError", - "SMDDPUnsupportedError", - "SMDDPRuntimeError", - "SMDDPInitializationError", - "SMDDPLogicError", - "SMDDPTimeoutError", - "SMDDPCUDAError", - "SMDDPNCCLError", - "SMDDPOOMError" -] \ No newline at end of file From f1dfc040833fcc0909032f18c1d55a90189169d8 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 15 Apr 2022 18:10:16 +0000 Subject: [PATCH 3/7] clean up error msg --- src/sagemaker_training/process.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index 1516ca5c..ae270ba3 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -83,18 +83,18 @@ async def watch(stream, error_classes, proc_per_host): ) print(line) # log only if necessary, remove node and rank id for de-duplication - 
err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) + err_line = err_line[err_line.find("") + 8 :] # in case error piped to stdout - err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) + err_line = err_line[err_line.find("") + 8 :] if start: - if line not in output: - output.append(err_line) + if err_line not in output: + output.append(err_line.strip(" :\n") + "\n") else: if any(err in err_line for err in (_PYTHON_ERRORS_ + error_classes)): # start logging error message if target exceptions found start = True - output.append(err_line + "\n") + output.append(err_line.strip(" :\n") + "\n") return " ".join(output) @@ -201,6 +201,8 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= **kwargs, ) stderr = " ".join(output) + # remove duplicate while preserve order + stderr = "\n".join(list(dict.fromkeys(stderr.split("\n")))).strip() else: stderr = None # remove extra quotes for subprocess.Popen @@ -216,7 +218,7 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= error_class = errors.ExecuteUserScriptError for error_name in error_classes: if error_name in stderr: - error_class = type(error_class_str, (errors._CalledProcessError,), {}) + error_class = type(error_name, (errors._CalledProcessError,), {}) break raise error_class( From cfe3b6c2ed8c3b04750009077cf87d00e3b901ab Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 17 May 2022 10:45:49 -0700 Subject: [PATCH 4/7] address comments --- src/sagemaker_training/process.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index ae270ba3..c9d39bb9 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -46,7 +46,7 @@ async def watch(stream, error_classes, proc_per_host): Args: stream: asyncio subprocess PIPE - error_classes (list): List of exception classes watch and raise + error_classes (list): List of exception classes to watch and 
raise proc_per_host (int): Number of processes per each host Returns: @@ -83,9 +83,9 @@ async def watch(stream, error_classes, proc_per_host): ) print(line) # log only if necessary, remove node and rank id for de-duplication - err_line = err_line[err_line.find("") + 8 :] + err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) # in case error piped to stdout - err_line = err_line[err_line.find("") + 8 :] + err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) if start: if err_line not in output: @@ -105,7 +105,7 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, ** Args: cmd (list): The command to be run - error_classes (list): List of exception classes watch and raise + error_classes (list): List of exception classes to watch and raise processes_per_host (int): Number of processes per host env: os.environ cwd (str): The location from which to run the command (default: None). @@ -139,7 +139,7 @@ def create(cmd, error_classes, processes_per_host, cwd=None, env=None, capture_e Args: cmd (list): The command to be run. - error_classes (list): List of exception classes watch and raise. + error_classes (list): List of exception classes to watch and raise. cwd (str): The location from which to run the command (default: None). If None, this defaults to the ``code_dir`` of the environment. env: os.environ @@ -176,7 +176,7 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= Args: cmd ([str]): The command to be run. - error_classes (list): List of exception classes watch and raise. + error_classes (list): List of exception classes to watch and raise. processes_per_host (int): Number of processes per host capture_error (bool): Whether or not to include stderr in the exception message (default: True). 
In either case, @@ -215,7 +215,9 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= extra_info = None if return_code == 137: extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)" + # default error class will be user script error error_class = errors.ExecuteUserScriptError + # use first found target error class if available for error_name in error_classes: if error_name in stderr: error_class = type(error_name, (errors._CalledProcessError,), {}) From 9b7d110677ee5cf7e42e3d0abc5d111050acc47d Mon Sep 17 00:00:00 2001 From: haohanchen-yagao Date: Wed, 18 May 2022 14:14:27 -0700 Subject: [PATCH 5/7] Add Error Categorization for SMMP --- src/sagemaker_training/environment.py | 12 ++++++ src/sagemaker_training/mpi.py | 52 +++++++++++++++++++++++- src/sagemaker_training/params.py | 1 + src/sagemaker_training/smdataparallel.py | 4 +- 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_training/environment.py b/src/sagemaker_training/environment.py index 3482151c..9a32d158 100644 --- a/src/sagemaker_training/environment.py +++ b/src/sagemaker_training/environment.py @@ -562,6 +562,10 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters self._master_hostname = list(hosts)[0] self._is_master = current_host == self._master_hostname + mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS) + self._is_modelparallel_enabled = mp_parameters and mp_parameters != '{}' + + @property def model_dir(self): # type: () -> str """The directory where models should be saved. @@ -909,6 +913,14 @@ def framework_module(self): # type: () -> str """ return self._framework_module + @property + def is_modelparallel_enabled(self): # type: () -> bool + """Whether SM ModelParallel is enabled. 
+ + Returns: + bool: True if SM ModelParallel is enabled + """ + return self._is_modelparallel_enabled def write_env_vars(env_vars=None): # type: (dict) -> None """Write the dictionary env_vars in the system, as environment variables. diff --git a/src/sagemaker_training/mpi.py b/src/sagemaker_training/mpi.py index 7c5fd741..d444864a 100644 --- a/src/sagemaker_training/mpi.py +++ b/src/sagemaker_training/mpi.py @@ -23,9 +23,17 @@ import psutil import gethostname -from sagemaker_training import logging_config, process, timeout - +from sagemaker_training import environment, errors, logging_config, process, timeout +from inspect import isclass logger = logging_config.get_logger() +try: + from smdistributed.modelparallel.backend import exceptions + # list of exceptions SMDDP wants training toolkit to catch and log + exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] +except ImportError as e: + logger.info("No exception classes found in smdistributed.modelparallel") + exception_classes = [] + logging.getLogger("paramiko").setLevel(logging.INFO) @@ -256,6 +264,46 @@ def _python_command(self): """ return super(MasterRunner, self)._python_command() + ["-m", "mpi4py"] + def run(self, wait=True, capture_error=False): + """Run the process. + + Args: + wait (bool): A boolean indicating whether to wait and check for errors. + Defaults to True. + capture_error (bool): A boolean indicating whether to direct stderr to a stream + that can later be read. Defaults to False. + + Returns: + process (subprocess.Popen): The spawned process. 
+ """ + self._setup() + + cmd = self._create_command() + + logging_config.log_script_invocation(cmd, self._env_vars) + + training_env = environment.Environment() + if wait: + process_spawned = process.check_error( + cmd, + exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError, + self._processes_per_host, + capture_error=capture_error, + cwd=environment.code_dir, + ) + else: + _, _, process_spawned = process.create( + cmd, + exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError, + self._processes_per_host, + capture_error=capture_error, + cwd=environment.code_dir, + ) + + self._tear_down() + return process_spawned + + _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """ SSH daemon not found, please install SSH to allow MPI to communicate different nodes in cluster. diff --git a/src/sagemaker_training/params.py b/src/sagemaker_training/params.py index c7eb8d1a..8aedeeb1 100644 --- a/src/sagemaker_training/params.py +++ b/src/sagemaker_training/params.py @@ -59,3 +59,4 @@ SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = ( "sagemaker_distributed_dataparallel_custom_mpi_options" ) # type: str +SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS" diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index eb862f85..b414d3fd 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -24,6 +24,8 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout from inspect import isclass + +logger = logging_config.get_logger() try: from smdistributed.dataparallel import exceptions # list of exceptions SMDDP wants training toolkit to catch and log @@ -32,7 +34,7 @@ logger.info("No exception classes found in smdistributed.dataparallel") exception_classes = [] -logger = logging_config.get_logger() + logging.getLogger("paramiko").setLevel(logging.INFO) From eee0ba00ad0dbfae4a3030d04b1525e560712e5b Mon 
Sep 17 00:00:00 2001 From: haohanchen-yagao Date: Thu, 19 May 2022 08:38:38 -0700 Subject: [PATCH 6/7] add pytorch errors for SMMP && minor fixes --- src/sagemaker_training/environment.py | 1 - src/sagemaker_training/mpi.py | 10 +++++++++- src/sagemaker_training/process.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_training/environment.py b/src/sagemaker_training/environment.py index 9a32d158..e04fa75b 100644 --- a/src/sagemaker_training/environment.py +++ b/src/sagemaker_training/environment.py @@ -565,7 +565,6 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS) self._is_modelparallel_enabled = mp_parameters and mp_parameters != '{}' - @property def model_dir(self): # type: () -> str """The directory where models should be saved. diff --git a/src/sagemaker_training/mpi.py b/src/sagemaker_training/mpi.py index d444864a..68fcf711 100644 --- a/src/sagemaker_training/mpi.py +++ b/src/sagemaker_training/mpi.py @@ -25,14 +25,22 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout from inspect import isclass + logger = logging_config.get_logger() try: from smdistributed.modelparallel.backend import exceptions - # list of exceptions SMDDP wants training toolkit to catch and log + # list of exceptions SMMP wants training toolkit to catch and log exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] except ImportError as e: logger.info("No exception classes found in smdistributed.modelparallel") exception_classes = [] +try: + from smdistributed.modelparallel.torch import exceptions as torch_exceptions + # list of torch exceptions SMMP wants training toolkit to catch and log + exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))] +except ImportError as e: + logger.info("No torch exception classes found in 
smdistributed.modelparallel") + logging.getLogger("paramiko").setLevel(logging.INFO) diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index c9d39bb9..70f6c547 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -91,7 +91,7 @@ async def watch(stream, error_classes, proc_per_host): if err_line not in output: output.append(err_line.strip(" :\n") + "\n") else: - if any(err in err_line for err in (_PYTHON_ERRORS_ + error_classes)): + if any(str(err) in err_line for err in (_PYTHON_ERRORS_ + error_classes if type(error_classes) == list else [error_classes])): # start logging error message if target exceptions found start = True output.append(err_line.strip(" :\n") + "\n") From 8a52e678abd21d6c9a17fe5ce3b653d373bbd96e Mon Sep 17 00:00:00 2001 From: haohanchen-yagao Date: Thu, 19 May 2022 09:37:48 -0700 Subject: [PATCH 7/7] feature: allow framework libraries to supply exceptions to track and report as failure reason. Added support for SMDDP and SMMP custom exceptions. Include custom exception as error class and de-duplicated stack trace as error message. Added tests for watching single, list of exceptions and also support existing internal exceptions. 
--- src/sagemaker_training/environment.py | 3 +- src/sagemaker_training/mpi.py | 39 ++++++---- src/sagemaker_training/params.py | 2 +- src/sagemaker_training/process.py | 99 +++++++++++++++++++----- src/sagemaker_training/smdataparallel.py | 16 ++-- test/unit/test_environment.py | 1 + test/unit/test_process.py | 49 +++++++++++- 7 files changed, 164 insertions(+), 45 deletions(-) diff --git a/src/sagemaker_training/environment.py b/src/sagemaker_training/environment.py index e04fa75b..899e9941 100644 --- a/src/sagemaker_training/environment.py +++ b/src/sagemaker_training/environment.py @@ -563,7 +563,7 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters self._is_master = current_host == self._master_hostname mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS) - self._is_modelparallel_enabled = mp_parameters and mp_parameters != '{}' + self._is_modelparallel_enabled = mp_parameters and mp_parameters != "{}" @property def model_dir(self): # type: () -> str @@ -921,6 +921,7 @@ def is_modelparallel_enabled(self): # type: () -> bool """ return self._is_modelparallel_enabled + def write_env_vars(env_vars=None): # type: (dict) -> None """Write the dictionary env_vars in the system, as environment variables. 
diff --git a/src/sagemaker_training/mpi.py b/src/sagemaker_training/mpi.py index 68fcf711..93a3a130 100644 --- a/src/sagemaker_training/mpi.py +++ b/src/sagemaker_training/mpi.py @@ -13,7 +13,7 @@ """This module contains functionality related to distributed training using MPI (Message Passing Interface).""" import argparse -import inspect +from inspect import getfile, isclass import logging import os import subprocess @@ -24,25 +24,29 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout -from inspect import isclass logger = logging_config.get_logger() +logging.getLogger("paramiko").setLevel(logging.INFO) + +exception_classes = None try: from smdistributed.modelparallel.backend import exceptions + # list of exceptions SMMP wants training toolkit to catch and log exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] -except ImportError as e: - logger.info("No exception classes found in smdistributed.modelparallel") - exception_classes = [] +except ImportError: + logger.info("No exception classes found in smdistributed.modelparallel.backend") + try: from smdistributed.modelparallel.torch import exceptions as torch_exceptions + # list of torch exceptions SMMP wants training toolkit to catch and log exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))] -except ImportError as e: - logger.info("No torch exception classes found in smdistributed.modelparallel") - +except ImportError: + logger.info("No torch exception classes found in smdistributed.modelparallel.torch") -logging.getLogger("paramiko").setLevel(logging.INFO) +if not exception_classes: + exception_classes = [errors.ExecuteUserScriptError] class WorkerRunner(process.ProcessRunner): @@ -251,12 +255,16 @@ def _create_command(self): "-x", "PATH", "-x", - "LD_PRELOAD=%s" % inspect.getfile(gethostname), + "LD_PRELOAD=%s" % getfile(gethostname), ] command.extend(additional_options) - for 
credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]: + for credential in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ]: if credential in os.environ: command.extend(["-x", credential]) @@ -294,7 +302,9 @@ def run(self, wait=True, capture_error=False): if wait: process_spawned = process.check_error( cmd, - exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError, + exception_classes + if training_env.is_modelparallel_enabled + else errors.ExecuteUserScriptError, self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, @@ -302,7 +312,9 @@ def run(self, wait=True, capture_error=False): else: _, _, process_spawned = process.create( cmd, - exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError, + exception_classes + if training_env.is_modelparallel_enabled + else errors.ExecuteUserScriptError, self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, @@ -312,7 +324,6 @@ def run(self, wait=True, capture_error=False): return process_spawned - _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """ SSH daemon not found, please install SSH to allow MPI to communicate different nodes in cluster. 
diff --git a/src/sagemaker_training/params.py b/src/sagemaker_training/params.py index 8aedeeb1..ac356d0b 100644 --- a/src/sagemaker_training/params.py +++ b/src/sagemaker_training/params.py @@ -59,4 +59,4 @@ SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = ( "sagemaker_distributed_dataparallel_custom_mpi_options" ) # type: str -SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS" +SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS" diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index 70f6c547..492d3d3a 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -17,6 +17,7 @@ import asyncio from asyncio.subprocess import PIPE +from inspect import isclass import os import re import subprocess @@ -37,7 +38,25 @@ # Default limit of the stream is 2 ** 16 KB, we can increase it to 128KB in subproc call _DEFAULT_BUF_SIZE = 1024 * 64 -async def watch(stream, error_classes, proc_per_host): + +def process_error_classes(error_classes): + """Process error classes and return a list of string. + Input could be class, string, or None + + Args: + error_classes (list): List of error classes + + Returns: + error_classes: processed classes + """ + if not error_classes: + return [] + if not isinstance(error_classes, list): + error_classes = [error_classes] + return [error.__name__ if isclass(error) else error for error in error_classes] + + +async def watch(stream, proc_per_host, error_classes=None): """Process the stdout and stderr streams on the fly. 
Decode the output lines Remove new line characters (if any) @@ -46,12 +65,13 @@ async def watch(stream, error_classes, proc_per_host): Args: stream: asyncio subprocess PIPE - error_classes (list): List of exception classes to watch and raise proc_per_host (int): Number of processes per each host + error_classes (list): List of exception classes to watch and raise Returns: output: Filtered stderr """ + error_classes = process_error_classes(error_classes) output = [] buf_size = _DEFAULT_BUF_SIZE start = False @@ -89,9 +109,16 @@ async def watch(stream, error_classes, proc_per_host): if start: if err_line not in output: - output.append(err_line.strip(" :\n") + "\n") + output.append(err_line.strip(" :\n") + "\n") else: - if any(str(err) in err_line for err in (_PYTHON_ERRORS_ + error_classes if type(error_classes) == list else [error_classes])): + if any( + str(err) in err_line + for err in ( + _PYTHON_ERRORS_ + error_classes + if isinstance(error_classes, list) + else [error_classes] + ) + ): # start logging error message if target exceptions found start = True output.append(err_line.strip(" :\n") + "\n") @@ -99,17 +126,17 @@ async def watch(stream, error_classes, proc_per_host): return " ".join(output) -async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **kwargs): +async def run_async(cmd, processes_per_host, env, cwd, stderr, error_classes=None, **kwargs): """Method responsible for launching asyncio subprocess shell Use asyncio gather to collect processed stdout and stderr Args: cmd (list): The command to be run - error_classes (list): List of exception classes to watch and raise processes_per_host (int): Number of processes per host env: os.environ cwd (str): The location from which to run the command (default: None). If None, this defaults to the ``code_dir`` of the environment. + error_classes (list): List of exception classes to watch and raise **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor. 
Returns: @@ -126,7 +153,8 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, ** ) output = await asyncio.gather( - watch(proc.stdout, error_classes, processes_per_host), watch(proc.stderr, error_classes, processes_per_host) + watch(proc.stdout, processes_per_host, error_classes=error_classes), + watch(proc.stderr, processes_per_host, error_classes=error_classes), ) logger.info("Waiting for the process to finish and give a return code.") return_code = await proc.wait() @@ -134,7 +162,15 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, ** return return_code, output, proc -def create(cmd, error_classes, processes_per_host, cwd=None, env=None, capture_error=False, **kwargs): +def create( + cmd, + error_classes, + processes_per_host, + cwd=None, + env=None, + capture_error=False, + **kwargs, +): """Spawn a process with asyncio for the given command. Args: @@ -158,17 +194,21 @@ def create(cmd, error_classes, processes_per_host, cwd=None, env=None, capture_e rc, output, proc = asyncio.run( run_async( cmd, - error_classes, processes_per_host, env=env or os.environ, cwd=cwd or environment.code_dir, stderr=stderr, + error_classes=error_classes, **kwargs, ) ) return rc, output, proc except Exception as e: # pylint: disable=broad-except - six.reraise(errors.ExecuteUserScriptError, errors.ExecuteUserScriptError(e), sys.exc_info()[2]) + six.reraise( + errors.ExecuteUserScriptError, + errors.ExecuteUserScriptError(e), + sys.exc_info()[2], + ) def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs): @@ -189,7 +229,7 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= Raises: ExecuteUserScriptError: If there is an exception raised when creating the process. 
""" - + error_classes = process_error_classes(error_classes) if capture_error: return_code, output, process = create( cmd, @@ -208,20 +248,35 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error= # remove extra quotes for subprocess.Popen cmd[-1] = cmd[-1].strip('"') process = subprocess.Popen( - cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs + cmd, + env=os.environ, + cwd=cwd or environment.code_dir, + stderr=stderr, + **kwargs, ) return_code = process.wait() if return_code: extra_info = None if return_code == 137: extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)" - # default error class will be user script error - error_class = errors.ExecuteUserScriptError - # use first found target error class if available - for error_name in error_classes: - if error_name in stderr: - error_class = type(error_name, (errors._CalledProcessError,), {}) - break + + # throw internal error classes first + internal_errors = [err for err in dir(errors) if isclass(getattr(errors, err))] + error_class = next( + (name for name in error_classes if name in internal_errors), "ExecuteUserScriptError" + ) + error_class = getattr(errors, error_class) + + # only replace ExecuteUserScriptError with custom library errors + if stderr and error_class == errors.ExecuteUserScriptError: + # find the first target error in stderr + error_name = next((str(name) for name in error_classes if str(name) in stderr), False) + if error_name: + error_class = type( + error_name, + (errors._CalledProcessError,), # pylint: disable=protected-access + {}, + ) raise error_class( cmd=" ".join(cmd) if isinstance(cmd, list) else cmd, @@ -275,7 +330,11 @@ def _create_command(self): six.moves.shlex_quote(arg) # pylint: disable=too-many-function-args for arg in self._args ] - return ["/bin/sh", "-c", '"./%s %s"' % (self._user_entry_point, " ".join(args))] + return [ + "/bin/sh", + "-c", + '"./%s %s"' % (self._user_entry_point, " ".join(args)), + ] 
def _python_command(self): # pylint: disable=no-self-use return [python_executable()] diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index b414d3fd..de5d08d9 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """Contains functionality related to SM Distributed Data Parallel Training.""" import argparse -import inspect +from inspect import getfile, isclass import json import logging import os @@ -23,19 +23,19 @@ import gethostname from sagemaker_training import environment, errors, logging_config, process, timeout -from inspect import isclass + logger = logging_config.get_logger() +logging.getLogger("paramiko").setLevel(logging.INFO) + try: from smdistributed.dataparallel import exceptions + # list of exceptions SMDDP wants training toolkit to catch and log exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))] -except ImportError as e: +except ImportError: logger.info("No exception classes found in smdistributed.dataparallel") - exception_classes = [] - - -logging.getLogger("paramiko").setLevel(logging.INFO) + exception_classes = [errors.ExecuteUserScriptError] class SMDataParallelRunner(process.ProcessRunner): @@ -173,7 +173,7 @@ def _get_mpirun_command( "-x", "RDMAV_FORK_SAFE=1", "-x", - "LD_PRELOAD=%s" % inspect.getfile(gethostname), + "LD_PRELOAD=%s" % getfile(gethostname), ] mpirun_command.extend(additional_options) diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py index 4b7c4f2d..4c6ca50a 100644 --- a/test/unit/test_environment.py +++ b/test/unit/test_environment.py @@ -212,6 +212,7 @@ def test_env_mapping_properties(training_env): "output_intermediate_dir", "is_master", "master_hostname", + "is_modelparallel_enabled", } diff --git a/test/unit/test_process.py b/test/unit/test_process.py index 3f0c10d2..86b9095e 100644 --- 
a/test/unit/test_process.py +++ b/test/unit/test_process.py @@ -105,7 +105,7 @@ async def test_watch(event_loop, capsys): expected_stream += ( "[1,mpirank:0,algo-1]:FileNotFoundError: [Errno 2] No such file or directory\n" ) - expected_errmsg = ":FileNotFoundError: [Errno 2] No such file or directory\n" + expected_errmsg = "FileNotFoundError: [Errno 2] No such file or directory\n" stream = asyncio.StreamReader() stream.feed_data(b"[1,10]:This is stdout\n") @@ -119,6 +119,53 @@ async def test_watch(event_loop, capsys): assert output == expected_errmsg +@pytest.mark.asyncio +async def test_watch_custom_error(event_loop, capsys): + num_processes_per_host = 8 + expected_stream = "[1,mpirank:10,algo-2]:This is stdout\n" + expected_stream += "[1,mpirank:10,algo-2]:This is stderr\n" + expected_stream += "[1,mpirank:0,algo-1]:SMDDPNCCLError: unhandled cuda error\n" + expected_errmsg = "SMDDPNCCLError: unhandled cuda error\n" + + stream = asyncio.StreamReader() + stream.feed_data(b"[1,10]:This is stdout\n") + stream.feed_data(b"[1,10]:This is stderr\n") + stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error") + stream.feed_eof() + + error_classes = ["SMDDPNCCLError"] + output = await process.watch(stream, num_processes_per_host, error_classes=error_classes) + captured_stream = capsys.readouterr() + assert captured_stream.out == expected_stream + assert output == expected_errmsg + + # test errors piped in stdout + stream = asyncio.StreamReader() + stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error") + stream.feed_eof() + + error_classes = ["SMDDPNCCLError"] + output = await process.watch(stream, num_processes_per_host, error_classes=error_classes) + assert output == expected_errmsg + + # test single item + stream = asyncio.StreamReader() + stream.feed_data(b"[1,0]:SMDDPNCCLError: unhandled cuda error") + stream.feed_eof() + error_classes = "SMDDPNCCLError" + output = await process.watch(stream, num_processes_per_host, error_classes=error_classes) + assert 
output == expected_errmsg + + # test internal error + expected_errmsg = "ImportModuleError: module does not exist\n" + stream = asyncio.StreamReader() + stream.feed_data(b"[1,0]:ImportModuleError: module does not exist") + stream.feed_eof() + error_classes = [errors.ImportModuleError] + output = await process.watch(stream, num_processes_per_host, error_classes=error_classes) + assert output == expected_errmsg + + @patch("asyncio.run", AsyncMock(side_effect=ValueError("FAIL"))) def test_create_error(): with pytest.raises(errors.ExecuteUserScriptError):