1717
1818import asyncio
1919from asyncio .subprocess import PIPE
20+ from inspect import isclass
2021import os
2122import re
2223import subprocess
3637_DEFAULT_BUF_SIZE = 1024 * 64
3738
3839
39- async def watch (stream , error_classes , proc_per_host ):
40+ def process_error_classes (error_classes ):
41+ """Process error classes and return a list of string.
42+ Input could be class, string, or None
43+
44+ Args:
45+ error_classes (list): List of error classes
46+
47+ Returns:
48+ error_classes: processed classes
49+ """
50+ if not error_classes :
51+ return []
52+ if not isinstance (error_classes , list ):
53+ error_classes = [error_classes ]
54+ return [error .__name__ if isclass (error ) else error for error in error_classes ]
55+
56+
57+ async def watch (stream , proc_per_host , error_classes = None ):
4058 """Process the stdout and stderr streams on the fly.
4159 Decode the output lines
4260 Remove new line characters (if any)
@@ -45,12 +63,13 @@ async def watch(stream, error_classes, proc_per_host):
4563
4664 Args:
4765 stream: asyncio subprocess PIPE
48- error_classes (list): List of exception classes to watch and raise
4966 proc_per_host (int): Number of processes per each host
67+ error_classes (list): List of exception classes to watch and raise
5068
5169 Returns:
5270 output: Filtered stderr
5371 """
72+ error_classes = process_error_classes (error_classes )
5473 output = []
5574 buf_size = _DEFAULT_BUF_SIZE
5675 start = False
@@ -90,25 +109,32 @@ async def watch(stream, error_classes, proc_per_host):
90109 if err_line not in output :
91110 output .append (err_line .strip (" :\n " ) + "\n " )
92111 else :
93- if any (str (err ) in err_line for err in (_PYTHON_ERRORS_ + error_classes if type (error_classes ) == list else [error_classes ])):
112+ if any (
113+ str (err ) in err_line
114+ for err in (
115+ _PYTHON_ERRORS_ + error_classes
116+ if isinstance (error_classes , list )
117+ else [error_classes ]
118+ )
119+ ):
94120 # start logging error message if target exceptions found
95121 start = True
96122 output .append (err_line .strip (" :\n " ) + "\n " )
97123
98124 return " " .join (output )
99125
100126
101- async def run_async (cmd , error_classes , processes_per_host , env , cwd , stderr , ** kwargs ):
127+ async def run_async (cmd , processes_per_host , env , cwd , stderr , error_classes = None , ** kwargs ):
102128 """Method responsible for launching asyncio subprocess shell
103- Use asyncio gather to collect processed stdout and stderr
129+ Usyncse asyncio gather to collect processed stdout and stderr
104130
105131 Args:
106132 cmd (list): The command to be run
107- error_classes (list): List of exception classes to watch and raise
108133 processes_per_host (int): Number of processes per host
109134 env: os.environ
110135 cwd (str): The location from which to run the command (default: None).
111136 If None, this defaults to the ``code_dir`` of the environment.
137+ error_classes (list): List of exception classes to watch and raise
112138 **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.
113139
114140 Returns:
@@ -125,8 +151,8 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **
125151 )
126152
127153 output = await asyncio .gather (
128- watch (proc .stdout , error_classes , processes_per_host ),
129- watch (proc .stderr , error_classes , processes_per_host ),
154+ watch (proc .stdout , processes_per_host , error_classes = error_classes ),
155+ watch (proc .stderr , processes_per_host , error_classes = error_classes ),
130156 )
131157 return_code = proc .returncode
132158 return return_code , output , proc
@@ -164,11 +190,11 @@ def create(
164190 rc , output , proc = asyncio .run (
165191 run_async (
166192 cmd ,
167- error_classes ,
168193 processes_per_host ,
169194 env = env or os .environ ,
170195 cwd = cwd or environment .code_dir ,
171196 stderr = stderr ,
197+ error_classes = error_classes ,
172198 ** kwargs ,
173199 )
174200 )
@@ -181,9 +207,7 @@ def create(
181207 )
182208
183209
184- def check_error (
185- cmd , error_classes , processes_per_host , cwd = None , capture_error = True , ** kwargs
186- ):
210+ def check_error (cmd , error_classes , processes_per_host , cwd = None , capture_error = True , ** kwargs ):
187211 """Run a commmand, raising an exception if there is an error.
188212
189213 Args:
@@ -201,7 +225,7 @@ def check_error(
201225 Raises:
202226 ExecuteUserScriptError: If there is an exception raised when creating the process.
203227 """
204-
228+ error_classes = process_error_classes ( error_classes )
205229 if capture_error :
206230 return_code , output , process = create (
207231 cmd ,
@@ -231,13 +255,24 @@ def check_error(
231255 extra_info = None
232256 if return_code == 137 :
233257 extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)"
234- # default error class will be user script error
235- error_class = errors .ExecuteUserScriptError
236- # use first found target error class if available
237- for error_name in error_classes :
238- if error_name in stderr :
239- error_class = type (error_name , (errors ._CalledProcessError ,), {})
240- break
258+
259+ # throw internal error classes first
260+ internal_errors = [err for err in dir (errors ) if isclass (getattr (errors , err ))]
261+ error_class = next (
262+ (name for name in error_classes if name in internal_errors ), "ExecuteUserScriptError"
263+ )
264+ error_class = getattr (errors , error_class )
265+
266+ # only replace ExecuteUserScriptError with custom library errors
267+ if stderr and error_class == errors .ExecuteUserScriptError :
268+ # find the first target error in stderr
269+ error_name = next ((str (name ) for name in error_classes if str (name ) in stderr ), False )
270+ if error_name :
271+ error_class = type (
272+ error_name ,
273+ (errors ._CalledProcessError ,), # pylint: disable=protected-access
274+ {},
275+ )
241276
242277 raise error_class (
243278 cmd = " " .join (cmd ) if isinstance (cmd , list ) else cmd ,
@@ -257,9 +292,7 @@ def python_executable():
257292 (str): The real path of the current Python executable.
258293 """
259294 if not sys .executable :
260- raise RuntimeError (
261- "Failed to retrieve the real path for the Python executable binary"
262- )
295+ raise RuntimeError ("Failed to retrieve the real path for the Python executable binary" )
263296 return sys .executable
264297
265298
@@ -281,9 +314,7 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host):
281314 self ._processes_per_host = processes_per_host
282315
283316 def _create_command (self ):
284- entrypoint_type = _entry_point_type .get (
285- environment .code_dir , self ._user_entry_point
286- )
317+ entrypoint_type = _entry_point_type .get (environment .code_dir , self ._user_entry_point )
287318
288319 if entrypoint_type is _entry_point_type .PYTHON_PACKAGE :
289320 entry_module = self ._user_entry_point .replace (".py" , "" )
0 commit comments