@@ -76,6 +76,37 @@ def empty_sagemaker_session():
7676 return ims
7777
7878
79+ def empty_sagemaker_session_with_null_failure_path ():
80+ ims = Mock (name = "sagemaker_session" )
81+ ims .default_bucket = Mock (name = "default_bucket" , return_value = BUCKET_NAME )
82+ ims .sagemaker_runtime_client = Mock (name = "sagemaker_runtime" )
83+ ims .sagemaker_client .describe_endpoint = Mock (return_value = ENDPOINT_DESC )
84+ ims .sagemaker_client .describe_endpoint_config = Mock (return_value = ENDPOINT_CONFIG_DESC )
85+
86+ ims .sagemaker_runtime_client .invoke_endpoint_async = Mock (
87+ name = "invoke_endpoint_async" ,
88+ return_value = {
89+ "OutputLocation" : ASYNC_OUTPUT_LOCATION ,
90+ },
91+ )
92+
93+ polling_timeout_error = PollingTimeoutError (
94+ message = "Inference could still be running" ,
95+ output_path = ASYNC_OUTPUT_LOCATION ,
96+ seconds = DEFAULT_WAITER_CONFIG .delay * DEFAULT_WAITER_CONFIG .max_attempts ,
97+ )
98+
99+ ims .s3_client = Mock (name = "s3_client" )
100+ ims .s3_client .get_object = Mock (
101+ name = "get_object" ,
102+ side_effect = [polling_timeout_error ],
103+ )
104+
105+ ims .s3_client .put_object = Mock (name = "put_object" )
106+
107+ return ims
108+
109+
79110def empty_predictor ():
80111 predictor = Mock (name = "predictor" )
81112 predictor .update_endpoint = Mock (name = "update_endpoint" )
@@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path():
161192 assert result .failure_path == ASYNC_FAILURE_LOCATION
162193
163194
195+ def test_async_predict_call_with_data_and_input_and_null_failure_path ():
196+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
197+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
198+ predictor_async .name = ASYNC_PREDICTOR
199+ data = DUMMY_DATA
200+
201+ result = predictor_async .predict_async (data = data , input_path = ASYNC_INPUT_LOCATION )
202+ assert sagemaker_session .s3_client .put_object .called
203+
204+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
205+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
206+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
207+
208+ expected_request_args = {
209+ "Accept" : DEFAULT_ACCEPT ,
210+ "InputLocation" : ASYNC_INPUT_LOCATION ,
211+ "EndpointName" : ENDPOINT ,
212+ }
213+
214+ call_args , kwargs = sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .call_args
215+ assert kwargs == expected_request_args
216+ assert result .output_path == ASYNC_OUTPUT_LOCATION
217+ assert result .failure_path is None
218+
219+
164220def test_async_predict_call_verify_exceptions ():
165221 sagemaker_session = empty_sagemaker_session ()
166222 predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
@@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions():
185241 assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
186242
187243
188- def test_async_predict_call_pass_through_success ():
244+ def test_async_predict_call_verify_exceptions_with_null_failure_path ():
245+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
246+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
247+
248+ input_location = "s3://some-input-path"
249+
250+ with pytest .raises (
251+ PollingTimeoutError ,
252+ match = f"No result at { ASYNC_OUTPUT_LOCATION } after polling for "
253+ f"{ DEFAULT_WAITER_CONFIG .delay * DEFAULT_WAITER_CONFIG .max_attempts } "
254+ f" seconds. Inference could still be running" ,
255+ ):
256+ predictor_async .predict (input_path = input_location , waiter_config = DEFAULT_WAITER_CONFIG )
257+
258+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
259+ assert sagemaker_session .s3_client .get_object .called
260+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
261+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
262+
263+
264+ def test_async_predict_call_pass_through_output_failure_paths ():
189265 sagemaker_session = empty_sagemaker_session ()
190266
191267 response_body = Mock ("body" )
@@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success():
222298 assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
223299
224300
301+ def test_async_predict_call_pass_through_with_null_failure_path ():
302+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
303+
304+ response_body = Mock ("body" )
305+ response_body .read = Mock ("read" , return_value = RETURN_VALUE )
306+ response_body .close = Mock ("close" , return_value = None )
307+
308+ sagemaker_session .s3_client = Mock (name = "s3_client" )
309+ sagemaker_session .s3_client .get_object = Mock (
310+ name = "get_object" ,
311+ return_value = {"Body" : response_body },
312+ )
313+ sagemaker_session .s3_client .put_object = Mock (name = "put_object" )
314+
315+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
316+
317+ sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async = Mock (
318+ name = "invoke_endpoint_async" ,
319+ return_value = {
320+ "OutputLocation" : ASYNC_OUTPUT_LOCATION ,
321+ },
322+ )
323+
324+ input_location = "s3://some-input-path"
325+
326+ result = predictor_async .predict (
327+ input_path = input_location ,
328+ )
329+
330+ assert result == RETURN_VALUE
331+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
332+ assert sagemaker_session .s3_client .get_waiter .called_with ("object_exists" )
333+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
334+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
335+
336+
225337def test_predict_async_call_invalid_input ():
226338 sagemaker_session = empty_sagemaker_session ()
227339 predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
0 commit comments