3636_DEFAULT_BUF_SIZE = 1024 * 64
3737
3838
39- async def watch (stream , error_classes , proc_per_host ):
39+ async def watch (stream , proc_per_host , error_classes = None ):
4040 """Process the stdout and stderr streams on the fly.
4141 Decode the output lines
4242 Remove new line characters (if any)
@@ -45,12 +45,16 @@ async def watch(stream, error_classes, proc_per_host):
4545
4646 Args:
4747 stream: asyncio subprocess PIPE
48- error_classes (list): List of exception classes to watch and raise
4948 proc_per_host (int): Number of processes per each host
49+ error_classes (list): List of exception classes to watch and raise
5050
5151 Returns:
5252 output: Filtered stderr
5353 """
54+ if not error_classes :
55+ error_classes = []
56+ if not isinstance (error_classes , list ):
57+ error_classes = [error_classes ]
5458 output = []
5559 buf_size = _DEFAULT_BUF_SIZE
5660 start = False
@@ -105,17 +109,17 @@ async def watch(stream, error_classes, proc_per_host):
105109 return " " .join (output )
106110
107111
108- async def run_async (cmd , error_classes , processes_per_host , env , cwd , stderr , ** kwargs ):
112+ async def run_async (cmd , processes_per_host , env , cwd , stderr , error_classes = None , ** kwargs ):
109113 """Method responsible for launching asyncio subprocess shell
110- Use asyncio gather to collect processed stdout and stderr
114+ Usyncse asyncio gather to collect processed stdout and stderr
111115
112116 Args:
113117 cmd (list): The command to be run
114- error_classes (list): List of exception classes to watch and raise
115118 processes_per_host (int): Number of processes per host
116119 env: os.environ
117120 cwd (str): The location from which to run the command (default: None).
118121 If None, this defaults to the ``code_dir`` of the environment.
122+ error_classes (list): List of exception classes to watch and raise
119123 **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.
120124
121125 Returns:
@@ -126,14 +130,18 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **
126130 Raises:
127131 ExecuteUserScriptError: If there is an exception raised when creating the process.
128132 """
133+ if not error_classes :
134+ error_classes = []
135+ if not isinstance (error_classes , list ):
136+ error_classes = [error_classes ]
129137 cmd = " " .join (cmd )
130138 proc = await asyncio .create_subprocess_shell (
131139 cmd , env = env , cwd = cwd , stdout = PIPE , stderr = stderr , ** kwargs
132140 )
133141
134142 output = await asyncio .gather (
135- watch (proc .stdout , error_classes , processes_per_host ),
136- watch (proc .stderr , error_classes , processes_per_host ),
143+ watch (proc .stdout , processes_per_host , error_classes = error_classes ),
144+ watch (proc .stderr , processes_per_host , error_classes = error_classes ),
137145 )
138146 return_code = proc .returncode
139147 return return_code , output , proc
@@ -166,16 +174,20 @@ def create(
166174 Raises:
167175 ExecuteUserScriptError: If there is an exception raised when creating the process.
168176 """
177+ if not error_classes :
178+ error_classes = []
179+ if not isinstance (error_classes , list ):
180+ error_classes = [error_classes ]
169181 try :
170182 stderr = PIPE if capture_error else None
171183 rc , output , proc = asyncio .run (
172184 run_async (
173185 cmd ,
174- error_classes ,
175186 processes_per_host ,
176187 env = env or os .environ ,
177188 cwd = cwd or environment .code_dir ,
178189 stderr = stderr ,
190+ error_classes = error_classes ,
179191 ** kwargs ,
180192 )
181193 )
@@ -206,7 +218,10 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=
206218 Raises:
207219 ExecuteUserScriptError: If there is an exception raised when creating the process.
208220 """
209-
221+ if not error_classes :
222+ error_classes = []
223+ if not isinstance (error_classes , list ):
224+ error_classes = [error_classes ]
210225 if capture_error :
211226 return_code , output , process = create (
212227 cmd ,
@@ -239,10 +254,11 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=
239254 # default error class will be user script error
240255 error_class = errors .ExecuteUserScriptError
241256 # use first found target error class if available
242- for error_name in error_classes :
243- if error_name in stderr :
244- error_class = type (error_name , (errors .ExecuteUserScriptError ,), {})
245- break
257+ if stderr :
258+ for error_name in error_classes :
259+ if str (error_name ) in stderr :
260+ error_class = type (error_name , (errors .ExecuteUserScriptError ,), {})
261+ break
246262
247263 raise error_class (
248264 cmd = " " .join (cmd ) if isinstance (cmd , list ) else cmd ,
0 commit comments