@@ -99,7 +99,7 @@ def predict(
9999 self ._input_path = input_path
100100 response = self ._submit_async_request (input_path , initial_args , inference_id )
101101 output_location = response ["OutputLocation" ]
102- failure_location = response [ "FailureLocation" ]
102+ failure_location = response . get ( "FailureLocation" )
103103 result = self ._wait_for_output (
104104 output_path = output_location , failure_path = failure_location , waiter_config = waiter_config
105105 )
@@ -145,7 +145,7 @@ def predict_async(
145145 self ._input_path = input_path
146146 response = self ._submit_async_request (input_path , initial_args , inference_id )
147147 output_location = response ["OutputLocation" ]
148- failure_location = response [ "FailureLocation" ]
148+ failure_location = response . get ( "FailureLocation" )
149149 response_async = AsyncInferenceResponse (
150150 predictor_async = self ,
151151 output_path = output_location ,
@@ -216,6 +216,35 @@ def _submit_async_request(
216216 return response
217217
218218 def _wait_for_output (self , output_path , failure_path , waiter_config ):
219+ """Retrieve output based on the presense of failure_path."""
220+ if failure_path is not None :
221+ return self ._check_output_and_failure_locations (
222+ output_path , failure_path , waiter_config
223+ )
224+
225+ return self ._check_output_location (output_path , waiter_config )
226+
227+ def _check_output_location (self , output_path , waiter_config ):
228+ """Check the Amazon S3 output path for the output.
229+
230+ Periodically check Amazon S3 output path for async inference result.
231+ Timeout automatically after max attempts reached
232+ """
233+ bucket , key = parse_s3_url (output_path )
234+ s3_waiter = self .s3_client .get_waiter ("object_exists" )
235+ try :
236+ s3_waiter .wait (Bucket = bucket , Key = key , WaiterConfig = waiter_config ._to_request_dict ())
237+ except WaiterError :
238+ raise PollingTimeoutError (
239+ message = "Inference could still be running" ,
240+ output_path = output_path ,
241+ seconds = waiter_config .delay * waiter_config .max_attempts ,
242+ )
243+ s3_object = self .s3_client .get_object (Bucket = bucket , Key = key )
244+ result = self .predictor ._handle_response (response = s3_object )
245+ return result
246+
247+ def _check_output_and_failure_locations (self , output_path , failure_path , waiter_config ):
219248 """Check the Amazon S3 output path for the output.
220249
221250 This method waits for either the output file or the failure file to be found on the
0 commit comments