
Commit 128dbdd (parent: 697a887)

feature: allow framework libraries to supply exceptions to track and report as failure reason.

Added support for SMDDP and SMMP custom exceptions. The custom exception is reported as the error class, with the de-duplicated stack trace as the error message. Added tests for watching a single exception, a list of exceptions, and the existing internal exceptions.
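In short, a framework library hands the toolkit a list of exception classes (or class names) to watch for; when the training process fails and one of those names appears in its stderr, the raised error, and therefore the reported failure reason, is named after that exception instead of the generic ExecuteUserScriptError. A minimal sketch of the intended call pattern follows; the exception names, command, and process count are illustrative only, not real smdistributed classes:

from sagemaker_training import errors, process

# Hypothetical exception names a framework library asks the toolkit to track.
framework_exceptions = ["SMDDPTimeoutError", "SMPCheckpointError"]  # illustrative

# check_error runs the command, scans stderr for the tracked names, and raises either an
# internal toolkit error or an exception type named after the first matched framework
# exception, falling back to ExecuteUserScriptError when nothing matches.
process.check_error(
    ["python", "train.py"],  # illustrative user entry point
    framework_exceptions or errors.ExecuteUserScriptError,
    processes_per_host=8,  # illustrative
    capture_error=True,
)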

File tree: 7 files changed (+131, -45 lines)

src/sagemaker_training/environment.py (2 additions, 1 deletion)

@@ -563,7 +563,7 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
         self._is_master = current_host == self._master_hostname
 
         mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS)
-        self._is_modelparallel_enabled = mp_parameters and mp_parameters != '{}'
+        self._is_modelparallel_enabled = mp_parameters and mp_parameters != "{}"
 
     @property
     def model_dir(self):  # type: () -> str
@@ -921,6 +921,7 @@ def is_modelparallel_enabled(self):  # type: () -> bool
         """
         return self._is_modelparallel_enabled
 
+
 def write_env_vars(env_vars=None):  # type: (dict) -> None
     """Write the dictionary env_vars in the system, as environment variables.

src/sagemaker_training/mpi.py (17 additions, 9 deletions)

@@ -13,7 +13,7 @@
 """This module contains functionality related to distributed training using
 MPI (Message Passing Interface)."""
 import argparse
-import inspect
+from inspect import getfile, isclass
 import logging
 import os
 import subprocess
@@ -24,23 +24,24 @@
 
 import gethostname
 from sagemaker_training import environment, errors, logging_config, process, timeout
-from inspect import isclass
 
 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)
 
 try:
     from smdistributed.modelparallel.backend import exceptions
+
     # list of exceptions SMMP wants training toolkit to catch and log
     exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
-except ImportError as e:
+except ImportError:
     logger.info("No exception classes found in smdistributed.modelparallel")
     exception_classes = []
 try:
     from smdistributed.modelparallel.torch import exceptions as torch_exceptions
+
     # list of torch exceptions SMMP wants training toolkit to catch and log
     exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))]
-except ImportError as e:
+except ImportError:
     logger.info("No torch exception classes found in smdistributed.modelparallel")
 
 
@@ -250,12 +251,16 @@ def _create_command(self):
             "-x",
             "PATH",
             "-x",
-            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
+            "LD_PRELOAD=%s" % getfile(gethostname),
         ]
 
         command.extend(additional_options)
 
-        for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+        for credential in [
+            "AWS_ACCESS_KEY_ID",
+            "AWS_SECRET_ACCESS_KEY",
+            "AWS_SESSION_TOKEN",
+        ]:
             if credential in os.environ:
                 command.extend(["-x", credential])
 
@@ -293,15 +298,19 @@ def run(self, wait=True, capture_error=False):
         if wait:
             process_spawned = process.check_error(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
             )
         else:
             _, _, process_spawned = process.create(
                 cmd,
-                exception_classes if training_env.is_modelparallel_enabled else errors.ExecuteUserScriptError,
+                exception_classes
+                if training_env.is_modelparallel_enabled
+                else errors.ExecuteUserScriptError,
                 self._processes_per_host,
                 capture_error=capture_error,
                 cwd=environment.code_dir,
@@ -311,7 +320,6 @@ def run(self, wait=True, capture_error=False):
         return process_spawned
 
 
-
 _SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
 SSH daemon not found, please install SSH to allow MPI to communicate different nodes in cluster.
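One detail worth calling out: dir(exceptions) yields attribute names, so exception_classes ends up as a list of class-name strings, which is what the process module later matches against stderr text. A standalone sketch of the same optional-import pattern with a placeholder library name, so it runs even where smdistributed is not installed:

import logging
from inspect import isclass

logger = logging.getLogger(__name__)

try:
    # Placeholder import; in mpi.py this is smdistributed.modelparallel.backend.exceptions.
    from some_framework import exceptions  # hypothetical module name
    # dir() returns attribute *names*; keep only those bound to classes, so the result
    # is a list of class-name strings rather than class objects.
    exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
except ImportError:
    logger.info("Framework exception module not installed; nothing extra to track")
    exception_classes = []

print(exception_classes)  # [] when the library is absent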

src/sagemaker_training/params.py (1 addition, 1 deletion)

@@ -59,4 +59,4 @@
 SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = (
     "sagemaker_distributed_dataparallel_custom_mpi_options"
 )  # type: str
-SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"
+SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"

src/sagemaker_training/process.py (57 additions, 26 deletions)

@@ -17,6 +17,7 @@
 
 import asyncio
 from asyncio.subprocess import PIPE
+from inspect import isclass
 import os
 import re
 import subprocess
@@ -36,7 +37,24 @@
 _DEFAULT_BUF_SIZE = 1024 * 64
 
 
-async def watch(stream, error_classes, proc_per_host):
+def process_error_classes(error_classes):
+    """Process error classes and return a list of strings.
+    Input could be a class, a string, or None.
+
+    Args:
+        error_classes (list): List of error classes
+
+    Returns:
+        error_classes: processed classes
+    """
+    if not error_classes:
+        return []
+    if not isinstance(error_classes, list):
+        error_classes = [error_classes]
+    return [error.__name__ if isclass(error) else error for error in error_classes]
+
+
+async def watch(stream, proc_per_host, error_classes=None):
     """Process the stdout and stderr streams on the fly.
     Decode the output lines
     Remove new line characters (if any)
@@ -45,12 +63,13 @@ async def watch(stream, error_classes, proc_per_host):
 
     Args:
         stream: asyncio subprocess PIPE
-        error_classes (list): List of exception classes to watch and raise
         proc_per_host (int): Number of processes per each host
+        error_classes (list): List of exception classes to watch and raise
 
     Returns:
         output: Filtered stderr
     """
+    error_classes = process_error_classes(error_classes)
     output = []
     buf_size = _DEFAULT_BUF_SIZE
     start = False
@@ -90,25 +109,32 @@ async def watch(stream, error_classes, proc_per_host):
             if err_line not in output:
                 output.append(err_line.strip(" :\n") + "\n")
         else:
-            if any(str(err) in err_line for err in (_PYTHON_ERRORS_ + error_classes if type(error_classes) == list else [error_classes])):
+            if any(
+                str(err) in err_line
+                for err in (
+                    _PYTHON_ERRORS_ + error_classes
+                    if isinstance(error_classes, list)
+                    else [error_classes]
+                )
+            ):
                 # start logging error message if target exceptions found
                 start = True
                 output.append(err_line.strip(" :\n") + "\n")
 
     return " ".join(output)
 
 
-async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **kwargs):
+async def run_async(cmd, processes_per_host, env, cwd, stderr, error_classes=None, **kwargs):
     """Method responsible for launching asyncio subprocess shell
     Use asyncio gather to collect processed stdout and stderr
 
     Args:
         cmd (list): The command to be run
-        error_classes (list): List of exception classes to watch and raise
         processes_per_host (int): Number of processes per host
         env: os.environ
         cwd (str): The location from which to run the command (default: None).
             If None, this defaults to the ``code_dir`` of the environment.
+        error_classes (list): List of exception classes to watch and raise
         **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.
 
     Returns:
@@ -125,8 +151,8 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **kwargs):
     )
 
     output = await asyncio.gather(
-        watch(proc.stdout, error_classes, processes_per_host),
-        watch(proc.stderr, error_classes, processes_per_host),
+        watch(proc.stdout, processes_per_host, error_classes=error_classes),
+        watch(proc.stderr, processes_per_host, error_classes=error_classes),
     )
     return_code = proc.returncode
     return return_code, output, proc
@@ -164,11 +190,11 @@ def create(
     rc, output, proc = asyncio.run(
         run_async(
             cmd,
-            error_classes,
            processes_per_host,
            env=env or os.environ,
            cwd=cwd or environment.code_dir,
            stderr=stderr,
+            error_classes=error_classes,
            **kwargs,
        )
    )
@@ -181,9 +207,7 @@
         )
 
 
-def check_error(
-    cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs
-):
+def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs):
     """Run a command, raising an exception if there is an error.
 
     Args:
@@ -201,7 +225,7 @@ def check_error(
     Raises:
         ExecuteUserScriptError: If there is an exception raised when creating the process.
     """
-
+    error_classes = process_error_classes(error_classes)
     if capture_error:
         return_code, output, process = create(
             cmd,
@@ -231,13 +255,24 @@
         extra_info = None
         if return_code == 137:
             extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)"
-        # default error class will be user script error
-        error_class = errors.ExecuteUserScriptError
-        # use first found target error class if available
-        for error_name in error_classes:
-            if error_name in stderr:
-                error_class = type(error_name, (errors._CalledProcessError,), {})
-                break
+
+        # throw internal error classes first
+        internal_errors = [err for err in dir(errors) if isclass(getattr(errors, err))]
+        error_class = next(
+            (name for name in error_classes if name in internal_errors), "ExecuteUserScriptError"
+        )
+        error_class = getattr(errors, error_class)
+
+        # only replace ExecuteUserScriptError with custom library errors
+        if stderr and error_class == errors.ExecuteUserScriptError:
+            # find the first target error in stderr
+            error_name = next((str(name) for name in error_classes if str(name) in stderr), False)
+            if error_name:
+                error_class = type(
+                    error_name,
+                    (errors._CalledProcessError,),  # pylint: disable=protected-access
+                    {},
+                )
 
         raise error_class(
             cmd=" ".join(cmd) if isinstance(cmd, list) else cmd,
@@ -257,9 +292,7 @@ def python_executable():
         (str): The real path of the current Python executable.
     """
     if not sys.executable:
-        raise RuntimeError(
-            "Failed to retrieve the real path for the Python executable binary"
-        )
+        raise RuntimeError("Failed to retrieve the real path for the Python executable binary")
     return sys.executable
 
 
@@ -281,9 +314,7 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host):
         self._processes_per_host = processes_per_host
 
     def _create_command(self):
-        entrypoint_type = _entry_point_type.get(
-            environment.code_dir, self._user_entry_point
-        )
+        entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point)
 
         if entrypoint_type is _entry_point_type.PYTHON_PACKAGE:
             entry_module = self._user_entry_point.replace(".py", "")
src/sagemaker_training/smdataparallel.py (5 additions, 7 deletions)

@@ -12,7 +12,7 @@
 # language governing permissions and limitations under the License.
 """Contains functionality related to SM Distributed Data Parallel Training."""
 import argparse
-import inspect
+from inspect import getfile, isclass
 import json
 import logging
 import os
@@ -23,7 +23,7 @@
 
 import gethostname
 from sagemaker_training import environment, errors, logging_config, process, timeout
-from inspect import isclass
+
 
 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)
@@ -33,7 +33,7 @@
 
     # list of exceptions SMDDP wants training toolkit to catch and log
     exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
-except ImportError as e:
+except ImportError:
     logger.info("No exception classes found in smdistributed.dataparallel")
     exception_classes = [errors.ExecuteUserScriptError]
 
@@ -173,7 +173,7 @@ def _get_mpirun_command(
             "-x",
             "RDMAV_FORK_SAFE=1",
             "-x",
-            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
+            "LD_PRELOAD=%s" % getfile(gethostname),
         ]
 
         mpirun_command.extend(additional_options)
@@ -231,9 +231,7 @@ def _create_command(self):
         # homogeneous mode uses 16 processes per host; 8 server; 8 worker
         smdataparallel_server_addr = self._master_hostname
         smdataparallel_server_port = 7592
-        host_list = [
-            "{}:{}".format(host, num_processes_per_host) for host in self._hosts
-        ]
+        host_list = ["{}:{}".format(host, num_processes_per_host) for host in self._hosts]
         smdataparallel_flag = "SMDATAPARALLEL_USE_HOMOGENEOUS=1"
         command = self._get_mpirun_command(
             num_hosts,

test/unit/test_environment.py (1 addition, 0 deletions)

@@ -212,6 +212,7 @@ def test_env_mapping_properties(training_env):
         "output_intermediate_dir",
         "is_master",
         "master_hostname",
+        "is_modelparallel_enabled",
     }
 
