diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 882d71d45d..979725b2f3 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -235,10 +235,14 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): sagemaker.local.utils.recursive_copy(host_dir, output_artifacts) # Tar Artifacts -> model.tar.gz and output.tar.gz - model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)] - output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)] - sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz')) - sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz')) + model_files = [os.path.join(model_artifacts, name) for name in + os.listdir(model_artifacts)] + output_files = [os.path.join(output_artifacts, name) for name in + os.listdir(output_artifacts)] + sagemaker.utils.create_tar_file(model_files, + os.path.join(compressed_artifacts, 'model.tar.gz')) + sagemaker.utils.create_tar_file(output_files, + os.path.join(compressed_artifacts, 'output.tar.gz')) if output_data_config['S3OutputPath'] == '': output_data = 'file://%s' % compressed_artifacts diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 49346bead8..80b809363f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -37,7 +37,7 @@ class Model(object): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, name=None, vpc_config=None, - sagemaker_session=None): + sagemaker_session=None, enable_network_isolation=False): """Initialize an SageMaker ``Model``. Args: @@ -58,6 +58,9 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n * 'SecurityGroupIds' (list[str]): List of security group ids. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + enable_network_isolation (Boolean): Default False. if True, enables network isolation in the endpoint, + isolating the model container. No inbound or outbound network calls can be made to or from the + model container. """ self.model_data = model_data self.image = image @@ -69,6 +72,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n self.sagemaker_session = sagemaker_session self._model_name = None self._is_compiled_model = False + self._enable_network_isolation = enable_network_isolation def prepare_container_def(self, instance_type, accelerator_type=None): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type. @@ -92,7 +96,7 @@ def enable_network_isolation(self): Returns: bool: If network isolation should be enabled or not. """ - return False + return self._enable_network_isolation def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None): """Create a SageMaker Model Entity diff --git a/src/sagemaker/tensorflow/serving.py b/src/sagemaker/tensorflow/serving.py index 9e141e5728..a680f2df30 100644 --- a/src/sagemaker/tensorflow/serving.py +++ b/src/sagemaker/tensorflow/serving.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import logging + import sagemaker from sagemaker.content_types import CONTENT_TYPE_JSON from sagemaker.fw_utils import create_image_uri @@ -88,7 +89,7 @@ def predict(self, data, initial_args=None): return super(Predictor, self).predict(data, args) -class Model(sagemaker.Model): +class Model(sagemaker.model.FrameworkModel): FRAMEWORK_NAME = 'tensorflow-serving' LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL' LOG_LEVEL_MAP = { @@ -99,7 +100,7 @@ class Model(sagemaker.Model): logging.CRITICAL: 'crit', } - def __init__(self, model_data, role, image=None, framework_version=TF_VERSION, + def __init__(self, model_data, role, entry_point=None, image=None, framework_version=TF_VERSION, container_log_level=None, predictor_cls=Predictor, **kwargs): """Initialize a Model. @@ -118,14 +119,23 @@ def __init__(self, model_data, role, image=None, framework_version=TF_VERSION, **kwargs: Keyword arguments passed to the ``Model`` initializer. """ super(Model, self).__init__(model_data=model_data, role=role, image=image, - predictor_cls=predictor_cls, **kwargs) + predictor_cls=predictor_cls, entry_point=entry_point, **kwargs) self._framework_version = framework_version self._container_log_level = container_log_level def prepare_container_def(self, instance_type, accelerator_type=None): image = self._get_image_uri(instance_type, accelerator_type) env = self._get_container_env() - return sagemaker.container_def(image, self.model_data, env) + + if self.entry_point: + model_data = sagemaker.utils.repack_model(self.entry_point, + self.source_dir, + self.model_data, + self.sagemaker_session) + else: + model_data = self.model_data + + return sagemaker.container_def(image, model_data, env) def _get_container_env(self): if not self._container_log_level: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 417619e38b..5f35e4f259 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -12,10 +12,12 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import contextlib import errno import os import random import re +import shutil import sys import tarfile import tempfile @@ -23,9 +25,11 @@ from datetime import datetime from functools import wraps +from six.moves.urllib import parse import six +import sagemaker ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' @@ -258,13 +262,10 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): def create_tar_file(source_files, target=None): """Create a tar file containing all the source_files - Args: source_files (List[str]): List of file paths that will be contained in the tar file - Returns: (str): path to created tar file - """ if target: filename = target @@ -278,6 +279,92 @@ def create_tar_file(source_files, target=None): return filename +@contextlib.contextmanager +def _tmpdir(suffix='', prefix='tmp'): + """Create a temporary directory with a context manager. The file is deleted when the context exits. + + The prefix, suffix, and dir arguments are the same as for mkstemp(). + + Args: + suffix (str): If suffix is specified, the file name will end with that suffix, otherwise there will be no + suffix. + prefix (str): If prefix is specified, the file name will begin with that prefix; otherwise, + a default prefix is used. + dir (str): If dir is specified, the file will be created in that directory; otherwise, a default directory is + used. + Returns: + str: path to the directory + """ + tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None) + yield tmp + shutil.rmtree(tmp) + + +def repack_model(inference_script, source_directory, model_uri, sagemaker_session): + """Unpack model tarball and creates a new model tarball with the provided code script. + + This function does the following: + - uncompresses model tarball from S3 or local system into a temp folder + - replaces the inference code from the model with the new code provided + - compresses the new model tarball and saves it in S3 or local file system + + Args: + inference_script (str): path or basename of the inference script that will be packed into the model + source_directory (str): path including all the files that will be packed into the model + model_uri (str): S3 or file system location of the original model tar + sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3. + + Returns: + str: path to the new packed model + """ + new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp() + + with _tmpdir() as tmp: + tmp_model_dir = os.path.join(tmp, 'model') + os.mkdir(tmp_model_dir) + + model_from_s3 = model_uri.startswith('s3://') + if model_from_s3: + local_model_path = os.path.join(tmp, 'tar_file') + download_file_from_url(model_uri, local_model_path, sagemaker_session) + + new_model_path = os.path.join(tmp, new_model_name) + else: + local_model_path = model_uri.replace('file://', '') + new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name) + + with tarfile.open(name=local_model_path, mode='r:gz') as t: + t.extractall(path=tmp_model_dir) + + code_dir = os.path.join(tmp_model_dir, 'code') + if os.path.exists(code_dir): + shutil.rmtree(code_dir, ignore_errors=True) + + dirname = source_directory if source_directory else os.path.dirname(inference_script) + + shutil.copytree(dirname, code_dir) + + with tarfile.open(new_model_path, mode='w:gz') as t: + t.add(tmp_model_dir, arcname=os.path.sep) + + if model_from_s3: + url = parse.urlparse(model_uri) + bucket, key = url.netloc, url.path.lstrip('/') + new_key = key.replace(os.path.basename(key), new_model_name) + + sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path) + return 's3://%s/%s' % (bucket, new_key) + else: + return 'file://%s' % new_model_path + + +def download_file_from_url(url, dst, sagemaker_session): + url = parse.urlparse(url) + bucket, key = url.netloc, url.path.lstrip('/') + + download_file(bucket, key, dst, sagemaker_session) + + def download_file(bucket_name, path, target, sagemaker_session): """Download a Single File from S3 into a local path diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt b/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt new file mode 100644 index 0000000000..f9ff036688 --- /dev/null +++ b/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb b/tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb new file mode 100644 index 0000000000..71ac858241 Binary files /dev/null and b/tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb differ diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000..74cf86632b Binary files /dev/null and b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 differ diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index new file mode 100644 index 0000000000..ac030a9d40 Binary files /dev/null and b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index differ diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py new file mode 100644 index 0000000000..507d0c44f3 --- /dev/null +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -0,0 +1,26 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import json + + +def input_handler(data, context): + data = json.loads(data.read().decode('utf-8')) + new_values = [x + 1 for x in data['instances']] + dumps = json.dumps({'instances': new_values}) + return dumps + + +def output_handler(data, context): + response_content_type = context.accept_header + prediction = data.content + return prediction, response_content_type diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index ea38ecd92c..05e0725d5c 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -12,7 +12,11 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import tarfile + import botocore.exceptions +import os + import pytest import sagemaker import sagemaker.predictor @@ -36,9 +40,10 @@ def instance_type(request): def tfs_predictor(instance_type, sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') model_data = sagemaker_session.upload_data( - path='tests/data/tensorflow-serving-test-model.tar.gz', + path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): model = Model(model_data=model_data, role='SageMakerRole', framework_version=tf_full_version, sagemaker_session=sagemaker_session) @@ -46,18 +51,76 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version): yield predictor +def tar_dir(directory, tmpdir): + target = os.path.join(str(tmpdir), 'model.tar.gz') + + with tarfile.open(target, mode='w:gz') as t: + t.add(directory, arcname=os.path.sep) + return target + + +@pytest.fixture +def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, + sagemaker_session, + tf_full_version, + tmpdir): + endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') + + model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'), + tmpdir) + + model_data = sagemaker_session.upload_data( + path=model_tar, + key_prefix='tensorflow-serving/models') + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): + model = Model(model_data=model_data, + role='SageMakerRole', + framework_version=tf_full_version, + sagemaker_session=sagemaker_session) + predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) + yield predictor + + +@pytest.fixture(scope='module') +def tfs_predictor_with_model_and_entry_point_separated(instance_type, + sagemaker_session, tf_full_version): + endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') + + model_data = sagemaker_session.upload_data( + path=os.path.join(tests.integ.DATA_DIR, + 'tensorflow-serving-test-model.tar.gz'), + key_prefix='tensorflow-serving/models') + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): + entry_point = os.path.join(tests.integ.DATA_DIR, + 'tfs/tfs-test-model-with-inference/code/inference.py') + model = Model(entry_point=entry_point, + model_data=model_data, + role='SageMakerRole', + framework_version=tf_full_version, + sagemaker_session=sagemaker_session) + predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) + yield predictor + + @pytest.fixture(scope='module') def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") instance_type = 'ml.c4.large' accelerator_type = 'ml.eia1.medium' - model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz', - key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + model_data = sagemaker_session.upload_data( + path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), + key_prefix='tensorflow-serving/models') + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): model = Model(model_data=model_data, role='SageMakerRole', framework_version=tf_full_version, sagemaker_session=sagemaker_session) - predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, accelerator_type=accelerator_type) + predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, + accelerator_type=accelerator_type) yield predictor @@ -81,6 +144,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator): assert expected_result == result +def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar): + input_data = {'instances': [1.0, 2.0, 5.0]} + expected_result = {'predictions': [4.0, 4.5, 6.0]} + + result = tfs_predictor_with_model_and_entry_point_same_tar.predict(input_data) + assert expected_result == result + + +def test_predict_with_model_and_entry_point_separated( + tfs_predictor_with_model_and_entry_point_separated): + input_data = {'instances': [1.0, 2.0, 5.0]} + expected_result = {'predictions': [4.0, 4.5, 6.0]} + + result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data) + assert expected_result == result + + def test_predict_generic_json(tfs_predictor): input_data = [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]] expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index 62ddf05f74..c00254eae0 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -154,7 +154,8 @@ } -def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_supported_input_mode_with_valid_input_types(session): # verify that the Estimator verifies the # input mode that an Algorithm supports. @@ -178,7 +179,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) # Creating a File mode Estimator with a File mode algorithm should work AlgorithmEstimator( @@ -186,7 +187,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -209,7 +210,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) # Creating a Pipe mode Estimator with a Pipe mode algorithm should work. AlgorithmEstimator( @@ -218,7 +219,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session train_instance_type='ml.m4.xlarge', train_instance_count=1, input_mode='Pipe', - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) any_input_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -241,7 +242,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=any_input_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=any_input_algo) # Creating a File mode Estimator with an algorithm that supports both input modes # should work. @@ -250,11 +251,12 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_supported_input_mode_with_bad_input_types(session): # verify that the Estimator verifies raises exceptions when # attempting to train with an incorrect input type @@ -278,7 +280,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) # Creating a Pipe mode Estimator with a File mode algorithm should fail. with pytest.raises(ValueError): @@ -288,7 +290,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): train_instance_type='ml.m4.xlarge', train_instance_count=1, input_mode='Pipe', - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -311,7 +313,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) # Creating a File mode Estimator with a Pipe mode algorithm should fail. with pytest.raises(ValueError): @@ -320,12 +322,13 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_trainining_channels_with_expected_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels['TrainingSpecification']['TrainingChannels'] = [ @@ -347,14 +350,14 @@ def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session) }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) + session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Pass training and validation channels. This should work @@ -365,7 +368,8 @@ def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session) @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_trainining_channels_with_invalid_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels['TrainingSpecification']['TrainingChannels'] = [ @@ -387,14 +391,14 @@ def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) + session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Passing only validation should fail as training is required. @@ -406,7 +410,8 @@ def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): estimator.fit({'training': 's3://some/data', 'training2': 's3://some/other/data'}) -def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_train_instance_types_valid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] @@ -414,7 +419,7 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): 'SupportedTrainingInstanceTypes' ] = train_instance_types - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=describe_algo_response ) @@ -423,7 +428,7 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) AlgorithmEstimator( @@ -431,11 +436,12 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m5.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_train_instance_types_invalid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] @@ -443,7 +449,7 @@ def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session 'SupportedTrainingInstanceTypes' ] = train_instance_types - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=describe_algo_response ) @@ -454,18 +460,19 @@ def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.8xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_distributed_training_validation(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_distributed_training_validation(session): distributed_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) distributed_algo['TrainingSpecification']['SupportsDistributedTraining'] = True single_instance_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) single_instance_algo['TrainingSpecification']['SupportsDistributedTraining'] = False - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=distributed_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=distributed_algo) # Distributed training should work for Distributed and Single instance. AlgorithmEstimator( @@ -473,7 +480,7 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) AlgorithmEstimator( @@ -481,10 +488,10 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=2, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=single_instance_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=single_instance_algo) # distributed training on a single instance algorithm should fail. with pytest.raises(ValueError): @@ -493,11 +500,12 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m5.2xlarge', train_instance_count=2, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_hyperparameter_integer_range_valid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_integer_range_valid_range(session): hyperparameters = [ { 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', @@ -515,21 +523,22 @@ def test_algorithm_hyperparameter_integer_range_valid_range(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(max_leaf_nodes=1) estimator.set_hyperparameters(max_leaf_nodes=100000) -def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_integer_range_invalid_range(session): hyperparameters = [ { 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', @@ -547,14 +556,14 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session) some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(ValueError): @@ -564,7 +573,8 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session) estimator.set_hyperparameters(max_leaf_nodes=100001) -def test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_continuous_range_valid_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -582,14 +592,14 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(max_leaf_nodes=0) @@ -598,7 +608,8 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session estimator.set_hyperparameters(max_leaf_nodes=1) -def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_continuous_range_invalid_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -616,14 +627,14 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_sessi some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(ValueError): @@ -633,7 +644,8 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_sessi estimator.set_hyperparameters(max_leaf_nodes=-0.1) -def test_algorithm_hyperparameter_categorical_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_categorical_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -649,14 +661,14 @@ def test_algorithm_hyperparameter_categorical_range(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(hp1='MXNet') @@ -669,7 +681,8 @@ def test_algorithm_hyperparameter_categorical_range(sagemaker_session): estimator.set_hyperparameters(hp1='MxNET') -def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_required_hyperparameters_not_provided(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -691,14 +704,14 @@ def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # hp1 is required and was not provided @@ -711,8 +724,9 @@ def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): estimator.fit({'training': 's3://some/place'}) +@patch('sagemaker.Session') @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_required_hyperparameters_are_provided(sagemaker_session): +def test_algorithm_required_hyperparameters_are_provided(session): hyperparameters = [ { 'Description': 'A categorical hyperparameter', @@ -741,21 +755,22 @@ def test_algorithm_required_hyperparameters_are_provided(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # All 3 Hyperparameters are provided estimator.set_hyperparameters(hp1='TF', hp2='TF2', free_text_hp1='Hello!') -def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_required_free_text_hyperparameter_not_provided(session): hyperparameters = [ { 'Name': 'free_text_hp1', @@ -776,14 +791,14 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_sess some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Calling fit with unset required hyperparameters should fail @@ -796,9 +811,10 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_sess estimator.set_hyperparameters(free_text_hp2='some text') +@patch('sagemaker.Session') @patch('sagemaker.algorithm.AlgorithmEstimator.create_model') -def test_algorithm_create_transformer(create_model, sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +def test_algorithm_create_transformer(create_model, session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -806,10 +822,10 @@ def test_algorithm_create_transformer(create_model, sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - estimator.latest_training_job = _TrainingJob(sagemaker_session, 'some-job-name') + estimator.latest_training_job = _TrainingJob(session, 'some-job-name') model = Mock() model.name = 'my-model' create_model.return_value = model @@ -821,8 +837,9 @@ def test_algorithm_create_transformer(create_model, sagemaker_session): assert transformer.model_name == 'my-model' -def test_algorithm_create_transformer_without_completed_training_job(sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +@patch('sagemaker.Session') +def test_algorithm_create_transformer_without_completed_training_job(session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -830,7 +847,7 @@ def test_algorithm_create_transformer_without_completed_training_job(sagemaker_s role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(RuntimeError) as error: @@ -839,10 +856,11 @@ def test_algorithm_create_transformer_without_completed_training_job(sagemaker_s @patch('sagemaker.algorithm.AlgorithmEstimator.create_model') -def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_create_transformer_with_product_id(create_model, session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['ProductId'] = 'some-product-id' - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -850,10 +868,10 @@ def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_se role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - estimator.latest_training_job = _TrainingJob(sagemaker_session, 'some-job-name') + estimator.latest_training_job = _TrainingJob(session, 'some-job-name') model = Mock() model.name = 'my-model' create_model.return_value = model @@ -862,8 +880,9 @@ def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_se assert transformer.env is None -def test_algorithm_enable_network_isolation_no_product_id(sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +@patch('sagemaker.Session') +def test_algorithm_enable_network_isolation_no_product_id(session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -871,17 +890,18 @@ def test_algorithm_enable_network_isolation_no_product_id(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) network_isolation = estimator.enable_network_isolation() assert network_isolation is False -def test_algorithm_enable_network_isolation_with_product_id(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_enable_network_isolation_with_product_id(session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['ProductId'] = 'some-product-id' - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -889,17 +909,18 @@ def test_algorithm_enable_network_isolation_with_product_id(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) network_isolation = estimator.enable_network_isolation() assert network_isolation is True -def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_encrypt_inter_container_traffic(session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['encrypt_inter_container_traffic'] = True - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -907,7 +928,7 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, encrypt_inter_container_traffic=True ) @@ -915,11 +936,12 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): assert encrypt_inter_container_traffic is True -def test_algorithm_no_required_hyperparameters(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_no_required_hyperparameters(session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) del some_algo['TrainingSpecification']['SupportedHyperParameters'] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) # Calling AlgorithmEstimator() with unset required hyperparameters # should fail if they are required. @@ -929,5 +951,5 @@ def test_algorithm_no_required_hyperparameters(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 265090c870..efd0ad499a 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -14,18 +14,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import shutil +import tarfile from datetime import datetime import os import re import time import pytest -from mock import call, patch, Mock +from mock import call, patch, Mock, MagicMock import sagemaker -from sagemaker.utils import get_config_value, name_from_base,\ - to_str, DeferredError, extract_name_from_job_arn, secondary_training_status_changed,\ - secondary_training_status_message, unique_name_from_base NAME = 'base_name' @@ -44,15 +43,15 @@ def test_get_config_value(): } } - assert get_config_value('local.region_name', config) == 'us-west-2' - assert get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'} + assert sagemaker.utils.get_config_value('local.region_name', config) == 'us-west-2' + assert sagemaker.utils.get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'} - assert get_config_value('does_not.exist', config) is None - assert get_config_value('other.key', None) is None + assert sagemaker.utils.get_config_value('does_not.exist', config) is None + assert sagemaker.utils.get_config_value('other.key', None) is None def test_deferred_error(): - de = DeferredError(ImportError("pretend the import failed")) + de = sagemaker.utils.DeferredError(ImportError("pretend the import failed")) with pytest.raises(ImportError) as _: # noqa: F841 de.something() @@ -61,7 +60,7 @@ def test_bad_import(): try: import pandas_is_not_installed as pd except ImportError as e: - pd = DeferredError(e) + pd = sagemaker.utils.DeferredError(e) assert pd is not None with pytest.raises(ImportError) as _: # noqa: F841 pd.DataFrame() @@ -69,44 +68,44 @@ def test_bad_import(): @patch('sagemaker.utils.sagemaker_timestamp') def test_name_from_base(sagemaker_timestamp): - name_from_base(NAME, short=False) + sagemaker.utils.name_from_base(NAME, short=False) assert sagemaker_timestamp.called_once @patch('sagemaker.utils.sagemaker_short_timestamp') def test_name_from_base_short(sagemaker_short_timestamp): - name_from_base(NAME, short=True) + sagemaker.utils.name_from_base(NAME, short=True) assert sagemaker_short_timestamp.called_once def test_unique_name_from_base(): - assert re.match(r'base-\d{10}-[a-f0-9]{4}', unique_name_from_base('base')) + assert re.match(r'base-\d{10}-[a-f0-9]{4}', sagemaker.utils.unique_name_from_base('base')) def test_unique_name_from_base_truncated(): assert re.match(r'real-\d{10}-[a-f0-9]{4}', - unique_name_from_base('really-long-name', max_length=20)) + sagemaker.utils.unique_name_from_base('really-long-name', max_length=20)) def test_to_str_with_native_string(): value = 'some string' - assert to_str(value) == value + assert sagemaker.utils.to_str(value) == value def test_to_str_with_unicode_string(): value = u'åñøthér strîng' - assert to_str(value) == value + assert sagemaker.utils.to_str(value) == value def test_name_from_tuning_arn(): arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11' - name = extract_name_from_job_arn(arn) + name = sagemaker.utils.extract_name_from_job_arn(arn) assert name == 'resnet-sgd-tuningjob-11-07-34-11' def test_name_from_training_arn(): arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b' - name = extract_name_from_job_arn(arn) + name = sagemaker.utils.extract_name_from_job_arn(arn) assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b' @@ -125,32 +124,33 @@ def test_name_from_training_arn(): def test_secondary_training_status_changed_true(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) assert changed is True def test_secondary_training_status_changed_false(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) assert changed is False def test_secondary_training_status_changed_prev_missing(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, {}) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, {}) assert changed is True def test_secondary_training_status_changed_prev_none(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, None) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, None) assert changed is True def test_secondary_training_status_changed_current_missing(): - changed = secondary_training_status_changed({}, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed({}, TRAINING_JOB_DESCRIPTION_1) assert changed is False def test_secondary_training_status_changed_empty(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, + TRAINING_JOB_DESCRIPTION_1) assert changed is False @@ -162,7 +162,8 @@ def test_secondary_training_status_message_status_changed(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, + TRAINING_JOB_DESCRIPTION_EMPTY) == expected def test_secondary_training_status_message_status_not_changed(): @@ -173,7 +174,8 @@ def test_secondary_training_status_message_status_not_changed(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, + TRAINING_JOB_DESCRIPTION_2) == expected def test_secondary_training_status_message_prev_missing(): @@ -184,7 +186,7 @@ def test_secondary_training_status_message_prev_missing(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) == expected @patch('os.makedirs') @@ -266,20 +268,202 @@ def test_download_file(): @patch('tarfile.open') def test_create_tar_file_with_provided_path(open): - open.return_value = open - open.__enter__ = Mock() - open.__exit__ = Mock(return_value=None) + files = mock_tarfile(open) + file_list = ['/tmp/a', '/tmp/b'] + path = sagemaker.utils.create_tar_file(file_list, target='/my/custom/path.tar.gz') assert path == '/my/custom/path.tar.gz' + assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] -@patch('tarfile.open') -@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path'))) -def test_create_tar_file_with_auto_generated_path(open): +def mock_tarfile(open): open.return_value = open + files = [] + + def add_files(filename, arcname): + files.append([filename, arcname]) + open.__enter__ = Mock() + open.__enter__().add = add_files open.__exit__ = Mock(return_value=None) - file_list = ['/tmp/a', '/tmp/b'] - path = sagemaker.utils.create_tar_file(file_list) + return files + + +@patch('tarfile.open') +@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path'))) +def test_create_tar_file_with_auto_generated_path(open): + files = mock_tarfile(open) + + path = sagemaker.utils.create_tar_file(['/tmp/a', '/tmp/b']) assert path == '/auto/generated/path' + assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] + + +def write_file(path, content): + with open(path, 'a') as f: + f.write(content) + + +def test_repack_model_without_source_dir(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + contents = [model_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model(os.path.join(source_dir, 'inference.py'), + None, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def test_repack_model_from_s3_saved_model_to_s3(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + contents = [model_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def test_repack_model_from_file_saves_model_to_file(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file([model_path], model_tar_path) + + sagemaker_session = MagicMock() + + file_mode_path = 'file://%s' % model_tar_path + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + file_mode_path, + sagemaker_session) + + assert os.path.dirname(new_model_uri) == os.path.dirname(file_mode_path) + assert list_tar_files(new_model_uri, tmpdir) == {'/code/inference.py', '/model'} + + +def test_repack_model_with_inference_code_should_replace_the_code(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'new-inference.py') + write_file(script_path, 'inference script') + + old_code_path = os.path.join(tmp, 'code') + os.mkdir(old_code_path) + old_script_path = os.path.join(old_code_path, 'old-inference.py') + write_file(old_script_path, 'old inference script') + contents = [model_path, old_code_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/new-inference.py', '/model'} + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def mock_s3_model_tar(contents, sagemaker_session, tmp): + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file(contents, model_tar_path) + mock_s3_download(sagemaker_session, model_tar_path) + + +def mock_s3_download(sagemaker_session, model_tar_path): + def download_file(_, target): + shutil.copy2(model_tar_path, target) + + sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = download_file + + +def mock_s3_upload(sagemaker_session, tmp): + dst = os.path.join(tmp, 'dst') + + class MockS3Object(object): + + def __init__(self, bucket, key): + self.bucket = bucket + self.key = key + + def upload_file(self, target): + shutil.copy2(target, dst) + + sagemaker_session.boto_session.resource().Object = MockS3Object + return dst + + +def list_tar_files(tar_ball, tmpdir): + tar_ball = tar_ball.replace('file://', '') + startpath = str(tmpdir.ensure('tmp', dir=True)) + + with tarfile.open(name=tar_ball, mode='r:gz') as t: + t.extractall(path=startpath) + + def walk(): + for root, dirs, files in os.walk(startpath): + path = root.replace(startpath, '') + for f in files: + yield '%s/%s' % (path, f) + + result = set(walk()) + return result if result else {}