From ba959bb121f86b3bc1200b09f57b0820f8c2f56a Mon Sep 17 00:00:00 2001 From: Praneeth Dodda Date: Wed, 17 May 2023 15:36:44 -0500 Subject: [PATCH 1/2] feature: handle use case where endpoint is created outside of python sdk with failure path as None --- .../async_inference_response.py | 23 ++++ src/sagemaker/predictor_async.py | 33 ++++- .../test_async_inference_response.py | 70 +++++++++++ tests/unit/test_predictor_async.py | 114 +++++++++++++++++- 4 files changed, 237 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/async_inference/async_inference_response.py b/src/sagemaker/async_inference/async_inference_response.py index fb195597c9..e9f82db2d0 100644 --- a/src/sagemaker/async_inference/async_inference_response.py +++ b/src/sagemaker/async_inference/async_inference_response.py @@ -87,8 +87,31 @@ def get_result( return self._result def _get_result_from_s3(self, output_path, failure_path): + """Retrieve output based on the presense of failure_path""" + if failure_path is not None: + return self._get_result_from_s3_output_failure_paths(output_path, failure_path) + + return self._get_result_from_s3_output_path(output_path) + + def _get_result_from_s3_output_path(self, output_path): """Get inference result from the output Amazon S3 path""" bucket, key = parse_s3_url(output_path) + try: + response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key) + return self.predictor_async.predictor._handle_response(response) + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise ObjectNotExistedError( + message="Inference could still be running", + output_path=output_path, + ) + raise UnexpectedClientError( + message=ex.response["Error"]["Message"], + ) + + def _get_result_from_s3_output_failure_paths(self, output_path, failure_path): + """Get inference result from the output & failure Amazon S3 path""" + bucket, key = parse_s3_url(output_path) try: response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key) return self.predictor_async.predictor._handle_response(response) diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index 2426b86a5c..5b5ed532be 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -99,7 +99,7 @@ def predict( self._input_path = input_path response = self._submit_async_request(input_path, initial_args, inference_id) output_location = response["OutputLocation"] - failure_location = response["FailureLocation"] + failure_location = response.get("FailureLocation") result = self._wait_for_output( output_path=output_location, failure_path=failure_location, waiter_config=waiter_config ) @@ -145,7 +145,7 @@ def predict_async( self._input_path = input_path response = self._submit_async_request(input_path, initial_args, inference_id) output_location = response["OutputLocation"] - failure_location = response["FailureLocation"] + failure_location = response.get("FailureLocation") response_async = AsyncInferenceResponse( predictor_async=self, output_path=output_location, @@ -216,6 +216,35 @@ def _submit_async_request( return response def _wait_for_output(self, output_path, failure_path, waiter_config): + """Retrieve output based on the presense of failure_path.""" + if failure_path is not None: + return self._check_output_and_failure_paths( + output_path, failure_path, waiter_config + ) + + return self._check_output_path(output_path, waiter_config) + + def _check_output_path(self, output_path, waiter_config): + """Check the Amazon S3 output path for the output. + + Periodically check Amazon S3 output path for async inference result. + Timeout automatically after max attempts reached + """ + bucket, key = parse_s3_url(output_path) + s3_waiter = self.s3_client.get_waiter("object_exists") + try: + s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict()) + except WaiterError: + raise PollingTimeoutError( + message="Inference could still be running", + output_path=output_path, + seconds=waiter_config.delay * waiter_config.max_attempts, + ) + s3_object = self.s3_client.get_object(Bucket=bucket, Key=key) + result = self.predictor._handle_response(response=s3_object) + return result + + def _check_output_and_failure_paths(self, output_path, failure_path, waiter_config): """Check the Amazon S3 output path for the output. This method waits for either the output file or the failure file to be found on the diff --git a/tests/unit/sagemaker/async_inference/test_async_inference_response.py b/tests/unit/sagemaker/async_inference/test_async_inference_response.py index a1ad6cf4a8..555af84e89 100644 --- a/tests/unit/sagemaker/async_inference/test_async_inference_response.py +++ b/tests/unit/sagemaker/async_inference/test_async_inference_response.py @@ -74,6 +74,32 @@ def empty_s3_client(): return s3_client +def empty_s3_client_to_verify_exceptions_for_null_failure_path(): + """ + Returns a mocked S3 client with the `get_object` method overridden + to raise different exceptions based on the input. + + Exceptions raised: + - `ObjectNotExistedError` + - `UnexpectedClientError` + + """ + s3_client = Mock(name="s3-client") + + object_error = ObjectNotExistedError("Inference could still be running", DEFAULT_OUTPUT_PATH) + + unexpected_error = UnexpectedClientError("some error message") + + s3_client.get_object = Mock( + name="get_object", + side_effect=[ + object_error, + unexpected_error, + ], + ) + return s3_client + + def mock_s3_client(): """ This function returns a mocked S3 client object that has a get_object method with a side_effect @@ -172,3 +198,47 @@ def test_get_result_verify_exceptions(): UnexpectedClientError, match="Encountered unexpected client error: some error message" ): async_inference_response.get_result() + + +def test_get_result_with_null_failure_path(): + """ + verifies that the result is returned correctly if no errors occur. + """ + # Initialize AsyncInferenceResponse + predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME)) + predictor_async.s3_client = mock_s3_client() + async_inference_response = AsyncInferenceResponse( + output_path=DEFAULT_OUTPUT_PATH, predictor_async=predictor_async, failure_path=None + ) + + result = async_inference_response.get_result() + assert async_inference_response._result == result + assert result == RETURN_VALUE + + +def test_get_result_verify_exceptions_with_null_failure_path(): + """ + Verifies that get_result method raises the expected exception + when an error occurs while fetching the result. + """ + # Initialize AsyncInferenceResponse + predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME)) + predictor_async.s3_client = empty_s3_client_to_verify_exceptions_for_null_failure_path() + async_inference_response = AsyncInferenceResponse( + output_path=DEFAULT_OUTPUT_PATH, + predictor_async=predictor_async, + failure_path=None, + ) + + # Test ObjectNotExistedError + with pytest.raises( + ObjectNotExistedError, + match=f"Object not exist at {DEFAULT_OUTPUT_PATH}. Inference could still be running", + ): + async_inference_response.get_result() + + # Test UnexpectedClientError + with pytest.raises( + UnexpectedClientError, match="Encountered unexpected client error: some error message" + ): + async_inference_response.get_result() diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index f0b69abe93..6754506680 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -76,6 +76,37 @@ def empty_sagemaker_session(): return ims +def empty_sagemaker_session_with_null_failure_path(): + ims = Mock(name="sagemaker_session") + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + + ims.sagemaker_runtime_client.invoke_endpoint_async = Mock( + name="invoke_endpoint_async", + return_value={ + "OutputLocation": ASYNC_OUTPUT_LOCATION, + }, + ) + + polling_timeout_error = PollingTimeoutError( + message="Inference could still be running", + output_path=ASYNC_OUTPUT_LOCATION, + seconds=DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts, + ) + + ims.s3_client = Mock(name="s3_client") + ims.s3_client.get_object = Mock( + name="get_object", + side_effect=[polling_timeout_error], + ) + + ims.s3_client.put_object = Mock(name="put_object") + + return ims + + def empty_predictor(): predictor = Mock(name="predictor") predictor.update_endpoint = Mock(name="update_endpoint") @@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path(): assert result.failure_path == ASYNC_FAILURE_LOCATION +def test_async_predict_call_with_data_and_input_and_null_failure_path(): + sagemaker_session = empty_sagemaker_session_with_null_failure_path() + predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session)) + predictor_async.name = ASYNC_PREDICTOR + data = DUMMY_DATA + + result = predictor_async.predict_async(data=data, input_path=ASYNC_INPUT_LOCATION) + assert sagemaker_session.s3_client.put_object.called + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called + assert sagemaker_session.sagemaker_client.describe_endpoint.not_called + assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called + + expected_request_args = { + "Accept": DEFAULT_ACCEPT, + "InputLocation": ASYNC_INPUT_LOCATION, + "EndpointName": ENDPOINT, + } + + call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args + assert kwargs == expected_request_args + assert result.output_path == ASYNC_OUTPUT_LOCATION + assert result.failure_path is None + + def test_async_predict_call_verify_exceptions(): sagemaker_session = empty_sagemaker_session() predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session)) @@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions(): assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called -def test_async_predict_call_pass_through_success(): +def test_async_predict_call_verify_exceptions_with_null_failure_path(): + sagemaker_session = empty_sagemaker_session_with_null_failure_path() + predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session)) + + input_location = "s3://some-input-path" + + with pytest.raises( + PollingTimeoutError, + match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for " + f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}" + f" seconds. Inference could still be running", + ): + predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG) + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called + assert sagemaker_session.s3_client.get_object.called + assert sagemaker_session.sagemaker_client.describe_endpoint.not_called + assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called + + +def test_async_predict_call_pass_through_output_failure_paths(): sagemaker_session = empty_sagemaker_session() response_body = Mock("body") @@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success(): assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called +def test_async_predict_call_pass_through_with_null_failure_path(): + sagemaker_session = empty_sagemaker_session_with_null_failure_path() + + response_body = Mock("body") + response_body.read = Mock("read", return_value=RETURN_VALUE) + response_body.close = Mock("close", return_value=None) + + sagemaker_session.s3_client = Mock(name="s3_client") + sagemaker_session.s3_client.get_object = Mock( + name="get_object", + return_value={"Body": response_body}, + ) + sagemaker_session.s3_client.put_object = Mock(name="put_object") + + predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session)) + + sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async = Mock( + name="invoke_endpoint_async", + return_value={ + "OutputLocation": ASYNC_OUTPUT_LOCATION, + }, + ) + + input_location = "s3://some-input-path" + + result = predictor_async.predict( + input_path=input_location, + ) + + assert result == RETURN_VALUE + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called + assert sagemaker_session.s3_client.get_waiter.called_with("object_exists") + assert sagemaker_session.sagemaker_client.describe_endpoint.not_called + assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called + + def test_predict_async_call_invalid_input(): sagemaker_session = empty_sagemaker_session() predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session)) From c9bab0defb1f10aec4b38eeea41f3198f1c53eae Mon Sep 17 00:00:00 2001 From: Praneeth Dodda Date: Wed, 17 May 2023 16:04:39 -0500 Subject: [PATCH 2/2] address black-check format issues --- src/sagemaker/predictor_async.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index 5b5ed532be..4c6324a541 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -218,9 +218,7 @@ def _submit_async_request( def _wait_for_output(self, output_path, failure_path, waiter_config): """Retrieve output based on the presense of failure_path.""" if failure_path is not None: - return self._check_output_and_failure_paths( - output_path, failure_path, waiter_config - ) + return self._check_output_and_failure_paths(output_path, failure_path, waiter_config) return self._check_output_path(output_path, waiter_config)