diff --git a/tests/integ/file_system_input_utils.py b/tests/integ/file_system_input_utils.py index deb8ff8569..94e2f8c4d3 100644 --- a/tests/integ/file_system_input_utils.py +++ b/tests/integ/file_system_input_utils.py @@ -18,12 +18,12 @@ from os import path import stat import tempfile -import time import uuid from botocore.exceptions import ClientError from fabric import Connection +from tests.integ.retry import retries from tests.integ.vpc_test_utils import check_or_create_vpc_resources_efs_fsx VPC_NAME = "sagemaker-efs-fsx-vpc" @@ -36,7 +36,6 @@ AMI_ID = "ami-082b5a644766e0e6f" MIN_COUNT = 1 MAX_COUNT = 1 -TIME_SLEEP_DURATION = 10 RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data") MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist") @@ -307,21 +306,6 @@ def _instance_profile_exists(sagemaker_session): return True -def retries(max_retry_count, exception_message_prefix): - current_retry_count = 0 - while current_retry_count <= max_retry_count: - yield current_retry_count - - current_retry_count += 1 - time.sleep(TIME_SLEEP_DURATION) - - raise Exception( - "{} has reached the maximum retry count {}".format( - exception_message_prefix, max_retry_count - ) - ) - - def tear_down(sagemaker_session, fs_resources): fsx_client = sagemaker_session.boto_session.client("fsx") file_system_fsx_id = fs_resources.file_system_fsx_id diff --git a/tests/integ/retry.py b/tests/integ/retry.py new file mode 100644 index 0000000000..7d69f4c6a1 --- /dev/null +++ b/tests/integ/retry.py @@ -0,0 +1,29 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import time + +DEFAULT_SLEEP_TIME_SECONDS = 10 + + +def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS): + for i in range(max_retry_count): + yield i + time.sleep(seconds_to_sleep) + + raise Exception( + "{} has reached the maximum retry count {}".format( + exception_message_prefix, max_retry_count + ) + ) diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 455d36d61b..83445d944d 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -14,7 +14,6 @@ import json import os -import time import pytest from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES @@ -30,6 +29,7 @@ from sagemaker.predictor import RealTimePredictor, json_serializer from sagemaker.sparkml.model import SparkMLModel from sagemaker.utils import sagemaker_timestamp +from tests.integ.retry import retries SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model") XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model") @@ -190,16 +190,11 @@ def test_inference_pipeline_model_deploy_with_update_endpoint( model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name) # Wait for endpoint to finish updating - max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout - current_retry_count = 0 - while current_retry_count <= max_retry_count: - if current_retry_count >= max_retry_count: - raise Exception("Endpoint status not 'InService' within expected timeout.") - time.sleep(30) + # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout + for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30): new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint( EndpointName=endpoint_name ) - current_retry_count += 1 if new_endpoint["EndpointStatus"] == "InService": break diff --git a/tests/integ/test_mxnet_train.py b/tests/integ/test_mxnet_train.py index 398d2cab9a..7a99aefab9 100644 --- a/tests/integ/test_mxnet_train.py +++ b/tests/integ/test_mxnet_train.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -24,6 +24,7 @@ from sagemaker.utils import sagemaker_timestamp from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.kms_utils import get_or_create_kms_key +from tests.integ.retry import retries from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -182,16 +183,11 @@ def test_deploy_model_with_update_endpoint( model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name) # Wait for endpoint to finish updating - max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout - current_retry_count = 0 - while current_retry_count <= max_retry_count: - if current_retry_count >= max_retry_count: - raise Exception("Endpoint status not 'InService' within expected timeout.") - time.sleep(30) + # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout + for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30): new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint( EndpointName=endpoint_name ) - current_retry_count += 1 if new_endpoint["EndpointStatus"] == "InService": break diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index 3f20dd8a26..b895fa6f62 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -23,6 +23,7 @@ import tests.integ from tests.integ import timeout +from tests.integ.retry import retries from tests.integ.s3_utils import assert_s3_files_exist ROLE = "SageMakerRole" @@ -199,15 +200,13 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type): assert expected_result == result -def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15): - actual_tags = None - for _ in range(retries): +def _assert_tags_match(sagemaker_client, resource_arn, tags, retry_count=15): + # endpoint and training tags might take minutes to propagate. + for _ in retries(retry_count, "Getting endpoint tags", seconds_to_sleep=30): actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"] if actual_tags: break - else: - # endpoint and training tags might take minutes to propagate. Sleeping. - time.sleep(30) + assert actual_tags == tags