23 changes: 23 additions & 0 deletions src/sagemaker/async_inference/async_inference_response.py
@@ -87,8 +87,31 @@ def get_result(
return self._result

def _get_result_from_s3(self, output_path, failure_path):
"""Retrieve output based on the presense of failure_path"""
if failure_path is not None:
return self._get_result_from_s3_output_failure_paths(output_path, failure_path)

return self._get_result_from_s3_output_path(output_path)

def _get_result_from_s3_output_path(self, output_path):
"""Get inference result from the output Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
return self.predictor_async.predictor._handle_response(response)
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise ObjectNotExistedError(
message="Inference could still be running",
output_path=output_path,
)
raise UnexpectedClientError(
message=ex.response["Error"]["Message"],
)

def _get_result_from_s3_output_failure_paths(self, output_path, failure_path):
"""Get inference result from the output & failure Amazon S3 path"""
bucket, key = parse_s3_url(output_path)
try:
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
return self.predictor_async.predictor._handle_response(response)
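Reviewer note: the rest of `_get_result_from_s3_output_failure_paths` is elided above, but the key behavioural change in this file is that `_get_result_from_s3` now dispatches on whether `failure_path` is set. A minimal, hedged sketch of the caller-side effect, assuming a deployed async endpoint whose responses omit `FailureLocation` (the endpoint name and S3 path below are illustrative, not taken from this PR):

from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor

# Illustrative names; assumes AWS credentials and an existing async endpoint.
predictor_async = AsyncPredictor(Predictor("my-async-endpoint"))

# Older endpoint configurations may not return a FailureLocation from
# InvokeEndpointAsync, so failure_path can legitimately be None here.
response = AsyncInferenceResponse(
    predictor_async=predictor_async,
    output_path="s3://my-bucket/async-output/result.out",
    failure_path=None,
)

# With failure_path=None, get_result() polls only the output path:
# ObjectNotExistedError means the inference may still be running, and
# any other S3 error surfaces as UnexpectedClientError.
result = response.get_result()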
33 changes: 31 additions & 2 deletions src/sagemaker/predictor_async.py
@@ -99,7 +99,7 @@ def predict(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
failure_location = response["FailureLocation"]
failure_location = response.get("FailureLocation")
result = self._wait_for_output(
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
)
@@ -145,7 +145,7 @@ def predict_async(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
failure_location = response["FailureLocation"]
failure_location = response.get("FailureLocation")
response_async = AsyncInferenceResponse(
predictor_async=self,
output_path=output_location,
@@ -216,6 +216,35 @@ def _submit_async_request(
return response

def _wait_for_output(self, output_path, failure_path, waiter_config):
"""Retrieve output based on the presense of failure_path."""
if failure_path is not None:
return self._check_output_and_failure_locations(
output_path, failure_path, waiter_config
)

return self._check_output_location(output_path, waiter_config)

def _check_output_location(self, output_path, waiter_config):
"""Check the Amazon S3 output path for the output.

Periodically check the Amazon S3 output path for the async inference result.
Time out automatically after the maximum number of attempts is reached.
"""
bucket, key = parse_s3_url(output_path)
s3_waiter = self.s3_client.get_waiter("object_exists")
try:
s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict())
except WaiterError:
raise PollingTimeoutError(
message="Inference could still be running",
output_path=output_path,
seconds=waiter_config.delay * waiter_config.max_attempts,
)
s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
result = self.predictor._handle_response(response=s3_object)
return result

def _check_output_and_failure_locations(self, output_path, failure_path, waiter_config):
"""Check the Amazon S3 output path for the output.

This method waits for either the output file or the failure file to be found on the
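Reviewer note: the `_check_output_and_failure_locations` docstring is truncated above. For orientation, the single-location path (`_check_output_location`) boils down to a boto3 `object_exists` waiter followed by a `get_object` call. A self-contained, hedged boto3 sketch of that polling loop, with placeholder bucket, key, and timing values:

import boto3
from botocore.exceptions import WaiterError

s3 = boto3.client("s3")
waiter = s3.get_waiter("object_exists")

try:
    # Roughly equivalent to WaiterConfig(delay=10, max_attempts=12):
    # poll every 10 seconds, up to 12 times, for the output object.
    waiter.wait(
        Bucket="my-bucket",
        Key="async-output/result.out",
        WaiterConfig={"Delay": 10, "MaxAttempts": 12},
    )
    body = s3.get_object(Bucket="my-bucket", Key="async-output/result.out")["Body"].read()
except WaiterError:
    # The SDK wraps this case in PollingTimeoutError; with no failure
    # path to check, a timeout only means the inference may still be running.
    body = None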
@@ -74,6 +74,32 @@ def empty_s3_client():
return s3_client


def empty_s3_client_to_verify_exceptions_for_null_failure_path():
"""
Returns a mocked S3 client whose `get_object` method is overridden
to raise a different exception on each successive call.

Exceptions raised:
- `ObjectNotExistedError`
- `UnexpectedClientError`

"""
s3_client = Mock(name="s3-client")

object_error = ObjectNotExistedError("Inference could still be running", DEFAULT_OUTPUT_PATH)

unexpected_error = UnexpectedClientError("some error message")

s3_client.get_object = Mock(
name="get_object",
side_effect=[
object_error,
unexpected_error,
],
)
return s3_client


def mock_s3_client():
"""
This function returns a mocked S3 client object that has a get_object method with a side_effect
@@ -172,3 +198,47 @@ def test_get_result_verify_exceptions():
UnexpectedClientError, match="Encountered unexpected client error: some error message"
):
async_inference_response.get_result()


def test_get_result_with_null_failure_path():
"""
Verifies that the result is returned correctly when no errors occur.
"""
# Initialize AsyncInferenceResponse
predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
predictor_async.s3_client = mock_s3_client()
async_inference_response = AsyncInferenceResponse(
output_path=DEFAULT_OUTPUT_PATH, predictor_async=predictor_async, failure_path=None
)

result = async_inference_response.get_result()
assert async_inference_response._result == result
assert result == RETURN_VALUE


def test_get_result_verify_exceptions_with_null_failure_path():
"""
Verifies that the get_result method raises the expected exception
when an error occurs while fetching the result.
"""
# Initialize AsyncInferenceResponse
predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
predictor_async.s3_client = empty_s3_client_to_verify_exceptions_for_null_failure_path()
async_inference_response = AsyncInferenceResponse(
output_path=DEFAULT_OUTPUT_PATH,
predictor_async=predictor_async,
failure_path=None,
)

# Test ObjectNotExistedError
with pytest.raises(
ObjectNotExistedError,
match=f"Object not exist at {DEFAULT_OUTPUT_PATH}. Inference could still be running",
):
async_inference_response.get_result()

# Test UnexpectedClientError
with pytest.raises(
UnexpectedClientError, match="Encountered unexpected client error: some error message"
):
async_inference_response.get_result()
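Note on the mock pattern used in `empty_s3_client_to_verify_exceptions_for_null_failure_path`: a `side_effect` list makes consecutive calls raise the listed exceptions in order, which is what lets the two `pytest.raises` blocks above share one client. A minimal standalone illustration (the built-in error types stand in for the SageMaker ones):

from unittest.mock import Mock

import pytest

s3_client = Mock(name="s3-client")
s3_client.get_object = Mock(
    name="get_object",
    side_effect=[KeyError("first call"), ValueError("second call")],
)

with pytest.raises(KeyError):
    s3_client.get_object(Bucket="b", Key="k")  # consumes the first side effect

with pytest.raises(ValueError):
    s3_client.get_object(Bucket="b", Key="k")  # consumes the second side effect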
114 changes: 113 additions & 1 deletion tests/unit/test_predictor_async.py
@@ -76,6 +76,37 @@ def empty_sagemaker_session():
return ims


def empty_sagemaker_session_with_null_failure_path():
ims = Mock(name="sagemaker_session")
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime")
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)

ims.sagemaker_runtime_client.invoke_endpoint_async = Mock(
name="invoke_endpoint_async",
return_value={
"OutputLocation": ASYNC_OUTPUT_LOCATION,
},
)

polling_timeout_error = PollingTimeoutError(
message="Inference could still be running",
output_path=ASYNC_OUTPUT_LOCATION,
seconds=DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts,
)

ims.s3_client = Mock(name="s3_client")
ims.s3_client.get_object = Mock(
name="get_object",
side_effect=[polling_timeout_error],
)

ims.s3_client.put_object = Mock(name="put_object")

return ims


def empty_predictor():
predictor = Mock(name="predictor")
predictor.update_endpoint = Mock(name="update_endpoint")
@@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path():
assert result.failure_path == ASYNC_FAILURE_LOCATION


def test_async_predict_call_with_data_and_input_and_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
predictor_async.name = ASYNC_PREDICTOR
data = DUMMY_DATA

result = predictor_async.predict_async(data=data, input_path=ASYNC_INPUT_LOCATION)
assert sagemaker_session.s3_client.put_object.called

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called

expected_request_args = {
"Accept": DEFAULT_ACCEPT,
"InputLocation": ASYNC_INPUT_LOCATION,
"EndpointName": ENDPOINT,
}

call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args
assert kwargs == expected_request_args
assert result.output_path == ASYNC_OUTPUT_LOCATION
assert result.failure_path is None


def test_async_predict_call_verify_exceptions():
sagemaker_session = empty_sagemaker_session()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
Expand All @@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions():
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_success():
def test_async_predict_call_verify_exceptions_with_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

input_location = "s3://some-input-path"

with pytest.raises(
PollingTimeoutError,
match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for "
f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}"
f" seconds. Inference could still be running",
):
predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG)

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.s3_client.get_object.called
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_output_failure_paths():
sagemaker_session = empty_sagemaker_session()

response_body = Mock("body")
@@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success():
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_async_predict_call_pass_through_with_null_failure_path():
sagemaker_session = empty_sagemaker_session_with_null_failure_path()

response_body = Mock("body")
response_body.read = Mock("read", return_value=RETURN_VALUE)
response_body.close = Mock("close", return_value=None)

sagemaker_session.s3_client = Mock(name="s3_client")
sagemaker_session.s3_client.get_object = Mock(
name="get_object",
return_value={"Body": response_body},
)
sagemaker_session.s3_client.put_object = Mock(name="put_object")

predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async = Mock(
name="invoke_endpoint_async",
return_value={
"OutputLocation": ASYNC_OUTPUT_LOCATION,
},
)

input_location = "s3://some-input-path"

result = predictor_async.predict(
input_path=input_location,
)

assert result == RETURN_VALUE
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
assert sagemaker_session.s3_client.get_waiter.called_with("object_exists")
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called


def test_predict_async_call_invalid_input():
sagemaker_session = empty_sagemaker_session()
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
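Reviewer note: taken together, these tests pin down the contract that a missing `FailureLocation` degrades gracefully. A hedged end-to-end sketch of the calling pattern they exercise, with placeholder endpoint, input path, and timing values:

from sagemaker.async_inference.waiter_config import WaiterConfig
from sagemaker.exceptions import PollingTimeoutError
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor

# Illustrative names; assumes AWS credentials and an existing async endpoint.
predictor_async = AsyncPredictor(Predictor("my-async-endpoint"))

# delay * max_attempts bounds how long the output location is polled
# before PollingTimeoutError is raised.
waiter_config = WaiterConfig(max_attempts=12, delay=10)

try:
    result = predictor_async.predict(
        input_path="s3://my-bucket/async-input/payload.json",
        waiter_config=waiter_config,
    )
except PollingTimeoutError:
    # With a null FailureLocation there is no failure file to inspect,
    # so a timeout only tells us the inference may still be running.
    result = None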