From ee615634a24a08d133b06a4277d39da53302e616 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 12:03:25 -0700 Subject: [PATCH 01/12] Add sagemaker.utils.repack function --- src/sagemaker/model.py | 7 +- src/sagemaker/tensorflow/serving.py | 18 +- src/sagemaker/utils.py | 105 +++++++++ .../00000123/assets/foo.txt | 1 + .../00000123/saved_model.pb | Bin 0 -> 8658 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 12 bytes .../00000123/variables/variables.index | Bin 0 -> 151 bytes .../code/inference.py | 26 ++ tests/integ/test_tfs.py | 87 ++++++- tests/unit/test_utils.py | 223 +++++++++++++++--- 10 files changed, 427 insertions(+), 40 deletions(-) create mode 100644 tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt create mode 100644 tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb create mode 100644 tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 create mode 100644 tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index create mode 100644 tests/data/tfs/tfs-test-model-with-inference/code/inference.py diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 49346bead8..e48ade0437 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -37,7 +37,7 @@ class Model(object): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, name=None, vpc_config=None, - sagemaker_session=None): + sagemaker_session=None, enable_network_isolation=False): """Initialize an SageMaker ``Model``. Args: @@ -58,6 +58,8 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n * 'SecurityGroupIds' (list[str]): List of security group ids. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + enable_network_isolation (Boolean): Default False. if True, enables network isolation in the endpoint, + isolating the model container. No inbound or outbound network calls can be made to or from the model container. """ self.model_data = model_data self.image = image @@ -69,6 +71,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n self.sagemaker_session = sagemaker_session self._model_name = None self._is_compiled_model = False + self._enable_network_isolation = enable_network_isolation def prepare_container_def(self, instance_type, accelerator_type=None): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type. @@ -92,7 +95,7 @@ def enable_network_isolation(self): Returns: bool: If network isolation should be enabled or not. 
""" - return False + return self._enable_network_isolation def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None): """Create a SageMaker Model Entity diff --git a/src/sagemaker/tensorflow/serving.py b/src/sagemaker/tensorflow/serving.py index 9e141e5728..a680f2df30 100644 --- a/src/sagemaker/tensorflow/serving.py +++ b/src/sagemaker/tensorflow/serving.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import logging + import sagemaker from sagemaker.content_types import CONTENT_TYPE_JSON from sagemaker.fw_utils import create_image_uri @@ -88,7 +89,7 @@ def predict(self, data, initial_args=None): return super(Predictor, self).predict(data, args) -class Model(sagemaker.Model): +class Model(sagemaker.model.FrameworkModel): FRAMEWORK_NAME = 'tensorflow-serving' LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL' LOG_LEVEL_MAP = { @@ -99,7 +100,7 @@ class Model(sagemaker.Model): logging.CRITICAL: 'crit', } - def __init__(self, model_data, role, image=None, framework_version=TF_VERSION, + def __init__(self, model_data, role, entry_point=None, image=None, framework_version=TF_VERSION, container_log_level=None, predictor_cls=Predictor, **kwargs): """Initialize a Model. @@ -118,14 +119,23 @@ def __init__(self, model_data, role, image=None, framework_version=TF_VERSION, **kwargs: Keyword arguments passed to the ``Model`` initializer. """ super(Model, self).__init__(model_data=model_data, role=role, image=image, - predictor_cls=predictor_cls, **kwargs) + predictor_cls=predictor_cls, entry_point=entry_point, **kwargs) self._framework_version = framework_version self._container_log_level = container_log_level def prepare_container_def(self, instance_type, accelerator_type=None): image = self._get_image_uri(instance_type, accelerator_type) env = self._get_container_env() - return sagemaker.container_def(image, self.model_data, env) + + if self.entry_point: + model_data = sagemaker.utils.repack_model(self.entry_point, + self.source_dir, + self.model_data, + self.sagemaker_session) + else: + model_data = self.model_data + + return sagemaker.container_def(image, model_data, env) def _get_container_env(self): if not self._container_log_level: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 417619e38b..0b835649c7 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -12,10 +12,12 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import contextlib import errno import os import random import re +import shutil import sys import tarfile import tempfile @@ -23,9 +25,11 @@ from datetime import datetime from functools import wraps +from six.moves.urllib import parse import six +import sagemaker ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' @@ -278,6 +282,107 @@ def create_tar_file(source_files, target=None): return filename +@contextlib.contextmanager +def _tmpdir(suffix='', prefix='tmp', dir=None): # type: (str, str, str) -> None + """Create a temporary directory with a context manager. The file is deleted when the context exits. + + The prefix, suffix, and dir arguments are the same as for mkstemp(). + + Args: + suffix (str): If suffix is specified, the file name will end with that suffix, otherwise there will be no + suffix. + prefix (str): If prefix is specified, the file name will begin with that prefix; otherwise, + a default prefix is used. 
+ dir (str): If dir is specified, the file will be created in that directory; otherwise, a default directory is + used. + Returns: + str: path to the directory + """ + tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir) + yield tmp + shutil.rmtree(tmp) + + +def repack_model(inference_script, source_directory, model_uri, sagemaker_session): + """Unpack model tarball and creates a new model tarball with the provided code script. + + This function does the following: + - uncompresses model tarball from S3 or local system into a temp folder + - replaces the inference code from the model with the new code provided + - compresses the new model tarball and saves it in S3 or local file system + + Args: + inference_script (str): path or basename of the inference script that will be packed into the model + source_directory (str): path including all the files that will be packed into the model + model_uri (str): S3 or file system location of the original model tar + sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3. + + Returns: + str: path to the new packed model + """ + new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp() + + with _tmpdir() as tmp: + + tmp_model_dir = os.path.join(tmp, 'model') + os.mkdir(tmp_model_dir) + + model_from_s3 = model_uri.startswith('s3://') + if model_from_s3: + + local_model_uri = os.path.join(tmp, 'tar_file') + download_file_from_url(model_uri, local_model_uri, sagemaker_session) + + new_model_path = os.path.join(tmp, new_model_name) + else: + local_model_uri = model_uri.replace('file://', '') + new_model_path = os.path.join(os.path.dirname(local_model_uri), new_model_name) + + with tarfile.open(name=local_model_uri, mode='r:gz') as t: + t.extractall(path=tmp_model_dir) + + code_dir = os.path.join(tmp_model_dir, 'code') + if os.path.exists(code_dir): + shutil.rmtree(code_dir, ignore_errors=True) + + os.mkdir(code_dir) + + source_files = _list_files(inference_script, source_directory) + + for source_file in source_files: + shutil.copy(source_file, code_dir) + + files_to_compress = [os.path.join(tmp_model_dir, file) + for file in os.listdir(tmp_model_dir)] + + tar_file = sagemaker.utils.create_tar_file(files_to_compress, new_model_path) + + if model_from_s3: + url = parse.urlparse(model_uri) + bucket, key = url.netloc, url.path.lstrip('/') + new_key = key.replace(os.path.basename(key), new_model_name) + + sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(tar_file) + return 's3://%s/%s' % (bucket, new_key) + else: + return 'file://%s' % new_model_path + + +def _list_files(script, directory): + if directory is None: + return [script] + + basedir = directory if directory else os.path.dirname(script) + return [os.path.join(basedir, name) for name in os.listdir(basedir)] + + +def download_file_from_url(url, dst, sagemaker_session): + url = parse.urlparse(url) + bucket, key = url.netloc, url.path.lstrip('/') + + download_file(bucket, key, dst, sagemaker_session) + + def download_file(bucket_name, path, target, sagemaker_session): """Download a Single File from S3 into a local path diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt b/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt new file mode 100644 index 0000000000..f9ff036688 --- /dev/null +++ b/tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git 
a/tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb b/tests/data/tfs/tfs-test-model-with-inference/00000123/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..71ac858241500dae98fba4ddb4c4e3abe4581ca2 GIT binary patch literal 8658 zcmcgx-*Ven8TYYc%l_;n%}LrirwzN90lMxqcHB0!-pFnhtpas%A#V{pqG@E}|xj(B$!mOIsTl3ywn3zag}-~>vK zeBT}nX%z|{-^cg=KnMjW9vjXP7i93pJqkugfj(YuRK_Hc`U<{kTSmZj|G*e=y3}`F zhvjdO##N{u`CNBg^R+!3Ocwr52;76>V|VBWY!yn1exqm!Asee9b6N`c(09GYGN=`$ z1Z+e3Ba06MJ2(}B+C!902wEKzLZv4XLLcZe?hW`SoyP~0HCVN{!%UuJ$Xy( zu%ez@eBU^70>4vwDIc(FuoBX;hn8(3{mPge+k)kAQK{Ieg|`BGpRw_>+)Rm2uRu+4 z48IKdH7>zeSY;P9MI>eT;Jc7uL&35A;D%uN-VM^#px7ypiq?1sLYi4GY(*j{>1b8b zkvB*P9zpfFW0?E^w+Q#9w{~(THz*X9%cvig@8-aP%Fl88xgPFU+}nH|h{xCDm#Zu%YzAIKCh&dM;KU!s?3y!?U>c z)ONVgf!3hhH+*@mb|b3eS@nY0sKcW}5l=kJuNN4;xF3F0*I*CeMc`pX`7ydOU51hj z0bGdg3u}R(7`A+wEJP+39Q@N7p;IcG{g= zgPps#ca3{@Zegs!Q1{;syo6PwWe@EDU09b+KvfWJR-+J^Z_P$o3jT*~Eg zKWtus6+MjPC>J^#T>`P=I1q!zERZCIl;ztLUu~|&bw$%P*OEGlM_FwCM4)W6!fX>} z9Yze6R;jr$l~FL4;M0V(73%HehE|fX%=NjUQHV zVG=4ssk=n;Wzziq6xx?zy}L%M<^M&0h#$UK8<%&u6`C|v+vPzds;H1FLZL%Qb9quD ziqiO(*ek`tf3}q-8~d!%!3U}r5DftfGYvzK#+l?CgvcostsPUeRD`?$p(l_(>Ch~n zBWLxw_o0Y7={q!`8j_xfl;K(EuSO6Z!|d|yIM6|SwUs{0~);7jcMQ}u9Ke3sz*by_fAt4en)9=bTY$=p_dBv6+mCv7Dv{M%WIO4&XXsqOG#e{C3N+WnteVwg{R)`UB-!w{Fy(vJ8L4M? zlApoGEb25pUtno-vN=+PFHek+soX)dS>)iP6(7pXP)%!Ij|{ioTf@yFYMnxl7(LSO z5wnT&(^QEY6+{J`&RkcDDo(OsiVb^aa@&l!UFxljx#fEJKbO_-v5^-}^+K|u;Z%st zG<3B4ruAoY3 zmc2g+d={E)T0A;qs!sw~Y}U9C)uKT$s` z!f?1Id5>qMi~TT_6=ctWemKW12pqsWDK=+v!o&nlxjdQcfCRoEI5ij69BV`;gWxzC zY@Xt5k+`BWQd40Xc1GDH`{yJK3aKIm-8dy`Mq%>x&FluSl{AjNb5X$ia{>PpzCFe0 zZ9)E5f&=^nP#qX z3BtUub}3s?*(BPtsrwH0AZqJIlnloS>8z_SqXmTAvl{vGW%d9^OTt0rbg|o}s-TSW zUAGB^&pT3|U z!zL6AVjp9Y1KZu=6! z14o>w;oYgYV@#FqLfcmOWebm6@yI%8{uXAfdVGLbqx`_4Q%89&J05IHE}NjKb2eWb NWKMQV2FsA>{{bSxjHUnp literal 0 HcmV?d00001 diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..74cf86632b56bb7222096c7d8f59400a4fa64c57 GIT binary patch literal 12 QcmZQzV6bOkaBu)&00qMVKmY&$ literal 0 HcmV?d00001 diff --git a/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index b/tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..ac030a9d4018dd553bb0a736b69cf91b5bb6911a GIT binary patch literal 151 zcmZQzVB=tvV&Y(AVB}8ZU=(7|U@>L0P?u+5=3=9m6 vK+JIPndE&C4dxv9v~U9hBU1{46I>`_)64Jf93gxl0YV`BcSE;IsrzjJ-**?> literal 0 HcmV?d00001 diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py new file mode 100644 index 0000000000..507d0c44f3 --- /dev/null +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -0,0 +1,26 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +import json + + +def input_handler(data, context): + data = json.loads(data.read().decode('utf-8')) + new_values = [x + 1 for x in data['instances']] + dumps = json.dumps({'instances': new_values}) + return dumps + + +def output_handler(data, context): + response_content_type = context.accept_header + prediction = data.content + return prediction, response_content_type diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index ea38ecd92c..45ead8b1b8 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -12,7 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import botocore.exceptions +import os +import tempfile + import pytest import sagemaker import sagemaker.predictor @@ -36,9 +38,10 @@ def instance_type(request): def tfs_predictor(instance_type, sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') model_data = sagemaker_session.upload_data( - path='tests/data/tensorflow-serving-test-model.tar.gz', + path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): model = Model(model_data=model_data, role='SageMakerRole', framework_version=tf_full_version, sagemaker_session=sagemaker_session) @@ -46,18 +49,71 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version): yield predictor +def tar_dir(directory): + + tmp = tempfile.mkdtemp() + + source_files = [os.path.join(directory, name) for name in os.listdir(directory)] + return sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, 'model.tar.gz')) + + +@pytest.fixture(scope='module') +def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, + sagemaker_session, tf_full_version): + endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') + + model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference')) + + model_data = sagemaker_session.upload_data( + path=model_tar, + key_prefix='tensorflow-serving/models') + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): + model = Model(model_data=model_data, + role='SageMakerRole', + framework_version=tf_full_version, + sagemaker_session=sagemaker_session) + predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) + yield predictor + + +@pytest.fixture(scope='module') +def tfs_predictor_with_model_and_entry_point_separated(instance_type, + sagemaker_session, tf_full_version): + endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') + + model_data = sagemaker_session.upload_data( + path=os.path.join(tests.integ.DATA_DIR, + 'tensorflow-serving-test-model.tar.gz'), + key_prefix='tensorflow-serving/models') + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): + model = Model(entry_point=os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference/code/inference.py'), + model_data=model_data, + role='SageMakerRole', + framework_version=tf_full_version, + sagemaker_session=sagemaker_session) + predictor = model.deploy(1, instance_type, 
endpoint_name=endpoint_name) + yield predictor + + @pytest.fixture(scope='module') def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") instance_type = 'ml.c4.large' accelerator_type = 'ml.eia1.medium' - model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz', - key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + model_data = sagemaker_session.upload_data( + path=os.path.join(tests.integ.DATA_DIR,'tensorflow-serving-test-model.tar.gz'), + key_prefix='tensorflow-serving/models') + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, + sagemaker_session): model = Model(model_data=model_data, role='SageMakerRole', framework_version=tf_full_version, sagemaker_session=sagemaker_session) - predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, accelerator_type=accelerator_type) + predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, + accelerator_type=accelerator_type) yield predictor @@ -81,6 +137,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator): assert expected_result == result +def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar): + input_data = {'instances': [1.0, 2.0, 5.0]} + expected_result = {'predictions': [4.0, 4.5, 6.0]} + + result = tfs_predictor_with_model_and_entry_point_same_tar.predict(input_data) + assert expected_result == result + + +def test_predict_with_model_and_entry_point_separated( + tfs_predictor_with_model_and_entry_point_separated): + input_data = {'instances': [1.0, 2.0, 5.0]} + expected_result = {'predictions': [4.0, 4.5, 6.0]} + + result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data) + assert expected_result == result + + def test_predict_generic_json(tfs_predictor): input_data = [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]] expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 265090c870..8b47c3ef73 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -14,18 +14,18 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import contextlib +import shutil +import tarfile from datetime import datetime import os import re import time import pytest -from mock import call, patch, Mock +from mock import call, patch, Mock, MagicMock import sagemaker -from sagemaker.utils import get_config_value, name_from_base,\ - to_str, DeferredError, extract_name_from_job_arn, secondary_training_status_changed,\ - secondary_training_status_message, unique_name_from_base NAME = 'base_name' @@ -44,15 +44,15 @@ def test_get_config_value(): } } - assert get_config_value('local.region_name', config) == 'us-west-2' - assert get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'} + assert sagemaker.utils.get_config_value('local.region_name', config) == 'us-west-2' + assert sagemaker.utils.get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'} - assert get_config_value('does_not.exist', config) is None - assert get_config_value('other.key', None) is None + assert sagemaker.utils.get_config_value('does_not.exist', config) is None + assert sagemaker.utils.get_config_value('other.key', None) is None def test_deferred_error(): - de = DeferredError(ImportError("pretend the import failed")) + de = sagemaker.utils.DeferredError(ImportError("pretend the import failed")) with pytest.raises(ImportError) as _: # noqa: F841 de.something() @@ -61,7 +61,7 @@ def test_bad_import(): try: import pandas_is_not_installed as pd except ImportError as e: - pd = DeferredError(e) + pd = sagemaker.utils.DeferredError(e) assert pd is not None with pytest.raises(ImportError) as _: # noqa: F841 pd.DataFrame() @@ -69,44 +69,44 @@ def test_bad_import(): @patch('sagemaker.utils.sagemaker_timestamp') def test_name_from_base(sagemaker_timestamp): - name_from_base(NAME, short=False) + sagemaker.utils.name_from_base(NAME, short=False) assert sagemaker_timestamp.called_once @patch('sagemaker.utils.sagemaker_short_timestamp') def test_name_from_base_short(sagemaker_short_timestamp): - name_from_base(NAME, short=True) + sagemaker.utils.name_from_base(NAME, short=True) assert sagemaker_short_timestamp.called_once def test_unique_name_from_base(): - assert re.match(r'base-\d{10}-[a-f0-9]{4}', unique_name_from_base('base')) + assert re.match(r'base-\d{10}-[a-f0-9]{4}', sagemaker.utils.unique_name_from_base('base')) def test_unique_name_from_base_truncated(): assert re.match(r'real-\d{10}-[a-f0-9]{4}', - unique_name_from_base('really-long-name', max_length=20)) + sagemaker.utils.unique_name_from_base('really-long-name', max_length=20)) def test_to_str_with_native_string(): value = 'some string' - assert to_str(value) == value + assert sagemaker.utils.to_str(value) == value def test_to_str_with_unicode_string(): value = u'åñøthér strîng' - assert to_str(value) == value + assert sagemaker.utils.to_str(value) == value def test_name_from_tuning_arn(): arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11' - name = extract_name_from_job_arn(arn) + name = sagemaker.utils.extract_name_from_job_arn(arn) assert name == 'resnet-sgd-tuningjob-11-07-34-11' def test_name_from_training_arn(): arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b' - name = extract_name_from_job_arn(arn) + name = sagemaker.utils.extract_name_from_job_arn(arn) assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b' @@ -125,32 +125,32 @@ def test_name_from_training_arn(): def 
test_secondary_training_status_changed_true(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) assert changed is True def test_secondary_training_status_changed_false(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) assert changed is False def test_secondary_training_status_changed_prev_missing(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, {}) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, {}) assert changed is True def test_secondary_training_status_changed_prev_none(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, None) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, None) assert changed is True def test_secondary_training_status_changed_current_missing(): - changed = secondary_training_status_changed({}, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed({}, TRAINING_JOB_DESCRIPTION_1) assert changed is False def test_secondary_training_status_changed_empty(): - changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1) assert changed is False @@ -162,7 +162,7 @@ def test_secondary_training_status_message_status_changed(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected def test_secondary_training_status_message_status_not_changed(): @@ -173,7 +173,7 @@ def test_secondary_training_status_message_status_not_changed(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == expected def test_secondary_training_status_message_prev_missing(): @@ -184,7 +184,7 @@ def test_secondary_training_status_message_prev_missing(): STATUS, MESSAGE ) - assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) == expected @patch('os.makedirs') @@ -283,3 +283,172 @@ def test_create_tar_file_with_auto_generated_path(open): file_list = ['/tmp/a', '/tmp/b'] path = sagemaker.utils.create_tar_file(file_list) assert path == '/auto/generated/path' + + +def write_file(path, content): + with open(path, 'a') as f: + f.write(content) + + +def test_repack_model_without_source_dir(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + contents = [model_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = 
mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model(os.path.join(source_dir, 'inference.py'), + None, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def test_repack_model_from_s3_saved_model_to_s3(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + contents = [model_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def test_repack_model_from_file_saves_model_to_file(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file([model_path], model_tar_path) + + sagemaker_session = MagicMock() + + file_mode_path = 'file://%s' % model_tar_path + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + file_mode_path, + sagemaker_session) + + assert os.path.dirname(new_model_uri) == os.path.dirname(file_mode_path) + assert list_tar_files(new_model_uri, tmpdir) == {'/code/inference.py', '/model'} + + +def test_repack_model_with_inference_code_should_replace_the_code(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'new-inference.py') + write_file(script_path, 'inference script') + + old_code_path = os.path.join(tmp, 'code') + os.mkdir(old_code_path) + old_script_path = os.path.join(old_code_path, 'old-inference.py') + write_file(old_script_path, 'old inference script') + contents = [model_path, old_code_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + new_model_uri = sagemaker.utils.repack_model('inference.py', + source_dir, + model_uri, + sagemaker_session) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/new-inference.py', '/model'} + assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + +def mock_s3_model_tar(contents, sagemaker_session, tmp): + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file(contents, model_tar_path) + mock_s3_download(sagemaker_session, model_tar_path) + + +def mock_s3_download(sagemaker_session, model_tar_path): + def download_file(_, target): + shutil.copy2(model_tar_path, target) + + sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = download_file + + +def 
mock_s3_upload(sagemaker_session, tmp): + dst = os.path.join(tmp, 'dst') + + class MockS3Object(object): + + def __init__(self, bucket, key): + self.bucket = bucket + self.key = key + + def upload_file(self, target): + shutil.copy2(target, dst) + + sagemaker_session.boto_session.resource().Object = MockS3Object + return dst + + +def list_tar_files(tar_ball, tmpdir): + tar_ball = tar_ball.replace('file://', '') + startpath = str(tmpdir.ensure('tmp', dir=True)) + + with tarfile.open(name=tar_ball, mode='r:gz') as t: + t.extractall(path=startpath) + + def walk(): + for root, dirs, files in os.walk(startpath): + path = root.replace(startpath, '') + for f in files: + yield '%s/%s' % (path, f) + + result = set(walk()) + return result if result else {} \ No newline at end of file From 83a1ff0c6e0817f0d85d7a0ba578f090841454e1 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 12:26:15 -0700 Subject: [PATCH 02/12] Fix flake8 --- src/sagemaker/model.py | 3 ++- tests/integ/test_tfs.py | 9 ++++++--- tests/unit/test_utils.py | 18 ++++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index e48ade0437..80b809363f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -59,7 +59,8 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. enable_network_isolation (Boolean): Default False. if True, enables network isolation in the endpoint, - isolating the model container. No inbound or outbound network calls can be made to or from the model container. + isolating the model container. No inbound or outbound network calls can be made to or from the + model container. 
""" self.model_data = model_data self.image = image diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 45ead8b1b8..14d7e726ff 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -59,7 +59,8 @@ def tar_dir(directory): @pytest.fixture(scope='module') def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, - sagemaker_session, tf_full_version): + sagemaker_session, + tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference')) @@ -90,7 +91,9 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type, with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - model = Model(entry_point=os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference/code/inference.py'), + entry_point = os.path.join(tests.integ.DATA_DIR, + 'tfs/tfs-test-model-with-inference/code/inference.py') + model = Model(entry_point=entry_point, model_data=model_data, role='SageMakerRole', framework_version=tf_full_version, @@ -105,7 +108,7 @@ def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): instance_type = 'ml.c4.large' accelerator_type = 'ml.eia1.medium' model_data = sagemaker_session.upload_data( - path=os.path.join(tests.integ.DATA_DIR,'tensorflow-serving-test-model.tar.gz'), + path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), key_prefix='tensorflow-serving/models') with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8b47c3ef73..469a7c5fdf 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -14,7 +14,6 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -import contextlib import shutil import tarfile from datetime import datetime @@ -150,7 +149,8 @@ def test_secondary_training_status_changed_current_missing(): def test_secondary_training_status_changed_empty(): - changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, + TRAINING_JOB_DESCRIPTION_1) assert changed is False @@ -162,7 +162,8 @@ def test_secondary_training_status_message_status_changed(): STATUS, MESSAGE ) - assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, + TRAINING_JOB_DESCRIPTION_EMPTY) == expected def test_secondary_training_status_message_status_not_changed(): @@ -173,7 +174,8 @@ def test_secondary_training_status_message_status_not_changed(): STATUS, MESSAGE ) - assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == expected + assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, + TRAINING_JOB_DESCRIPTION_2) == expected def test_secondary_training_status_message_prev_missing(): @@ -316,7 +318,7 @@ def test_repack_model_without_source_dir(tmpdir): sagemaker_session) assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} - assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) def test_repack_model_from_s3_saved_model_to_s3(tmpdir): @@ -345,7 +347,7 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir): sagemaker_session) assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} - assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) def test_repack_model_from_file_saves_model_to_file(tmpdir): @@ -405,7 +407,7 @@ def test_repack_model_with_inference_code_should_replace_the_code(tmpdir): sagemaker_session) assert list_tar_files(fake_upload_path, tmpdir) == {'/code/new-inference.py', '/model'} - assert re.match('^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) def mock_s3_model_tar(contents, sagemaker_session, tmp): @@ -451,4 +453,4 @@ def walk(): yield '%s/%s' % (path, f) result = set(walk()) - return result if result else {} \ No newline at end of file + return result if result else {} From 012f16ac30c3fe5b21f59b136ecbb98667f4860d Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 12:31:38 -0700 Subject: [PATCH 03/12] Fix flake8 --- src/sagemaker/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0b835649c7..4caae4b3a9 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -283,7 +283,7 @@ def create_tar_file(source_files, target=None): @contextlib.contextmanager -def _tmpdir(suffix='', prefix='tmp', dir=None): # type: (str, str, str) -> None +def _tmpdir(suffix='', prefix='tmp'): """Create a temporary directory with a context manager. The file is deleted when the context exits. The prefix, suffix, and dir arguments are the same as for mkstemp(). 
@@ -298,7 +298,7 @@ def _tmpdir(suffix='', prefix='tmp', dir=None): # type: (str, str, str) -> None Returns: str: path to the directory """ - tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir) + tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None) yield tmp shutil.rmtree(tmp) From c35487820d9542664ed3c79aadec22881a1f5b94 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 12:42:21 -0700 Subject: [PATCH 04/12] Fix flake8 --- tests/integ/test_tfs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 14d7e726ff..e923b17794 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import botocore.exceptions import os import tempfile From fde4d9ffc3d4ff9276de8bbd1f3a1eec88da0e07 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 15:04:28 -0700 Subject: [PATCH 05/12] Handle PR comments --- src/sagemaker/utils.py | 17 +++++++---------- tests/integ/test_tfs.py | 12 ++++++------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 4caae4b3a9..b71feba59b 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -330,27 +330,24 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio model_from_s3 = model_uri.startswith('s3://') if model_from_s3: - local_model_uri = os.path.join(tmp, 'tar_file') - download_file_from_url(model_uri, local_model_uri, sagemaker_session) + local_model_path = os.path.join(tmp, 'tar_file') + download_file_from_url(model_uri, local_model_path, sagemaker_session) new_model_path = os.path.join(tmp, new_model_name) else: - local_model_uri = model_uri.replace('file://', '') - new_model_path = os.path.join(os.path.dirname(local_model_uri), new_model_name) + local_model_path = model_uri.replace('file://', '') + new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name) - with tarfile.open(name=local_model_uri, mode='r:gz') as t: + with tarfile.open(name=local_model_path, mode='r:gz') as t: t.extractall(path=tmp_model_dir) code_dir = os.path.join(tmp_model_dir, 'code') if os.path.exists(code_dir): shutil.rmtree(code_dir, ignore_errors=True) - os.mkdir(code_dir) + dirname = source_directory if source_directory else os.path.dirname(inference_script) - source_files = _list_files(inference_script, source_directory) - - for source_file in source_files: - shutil.copy(source_file, code_dir) + shutil.copytree(dirname, code_dir) files_to_compress = [os.path.join(tmp_model_dir, file) for file in os.listdir(tmp_model_dir)] diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index e923b17794..c4d4f1299f 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -50,21 +50,21 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version): yield predictor -def tar_dir(directory): - - tmp = tempfile.mkdtemp() +def tar_dir(directory, tmpdir): source_files = [os.path.join(directory, name) for name in os.listdir(directory)] - return sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, 'model.tar.gz')) + return sagemaker.utils.create_tar_file(source_files, os.path.join(str(tmpdir), 'model.tar.gz')) @pytest.fixture(scope='module') def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, sagemaker_session, - tf_full_version): + tf_full_version, + tmpdir): endpoint_name = 
sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') - model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference')) + model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'), + tmpdir) model_data = sagemaker_session.upload_data( path=model_tar, From 3bbc3bb7b7923682f34311c259cabbfc0807820c Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 15:08:47 -0700 Subject: [PATCH 06/12] Fix flake8 --- tests/integ/test_tfs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index c4d4f1299f..e4736702c7 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -14,7 +14,6 @@ import botocore.exceptions import os -import tempfile import pytest import sagemaker From eeec58d1efcf6d4ac0c923f503d261de0c896a50 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Mon, 13 May 2019 16:51:33 -0700 Subject: [PATCH 07/12] Handle PR comments --- src/sagemaker/fw_utils.py | 21 ++++++++-------- src/sagemaker/local/image.py | 9 +++---- src/sagemaker/utils.py | 21 ++++++++++------ tests/integ/test_tfs.py | 4 +-- tests/unit/test_utils.py | 48 ++++++++++++++++++++++++++++++------ 5 files changed, 70 insertions(+), 33 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index daf0b60877..0dbb36e17c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -182,9 +182,16 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, tmp = tempfile.mkdtemp() try: - source_files = _list_files_to_compress(script, directory) + dependencies - tar_file = sagemaker.utils.create_tar_file(source_files, - os.path.join(tmp, _TAR_SOURCE_FILENAME)) + if directory: + source_files = dependencies + dir_files = [directory] + else: + source_files = [script] + dependencies + dir_files = [] + + tar_file = sagemaker.utils.create_tar_file(source_files=source_files, + dir_files=dir_files, + target=os.path.join(tmp, _TAR_SOURCE_FILENAME)) if kms_key: extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key} @@ -198,14 +205,6 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name) -def _list_files_to_compress(script, directory): - if directory is None: - return [script] - - basedir = directory if directory else os.path.dirname(script) - return [os.path.join(basedir, name) for name in os.listdir(basedir)] - - def framework_name_from_image(image_name): """Extract the framework and Python version from the image name. 
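A minimal sketch of the intermediate create_tar_file signature introduced in this revision (the dir_files parameter is removed again later in the series); the paths below are hypothetical:

    import sagemaker.utils

    # Loose files are added under their basenames; the contents of each directory listed in
    # dir_files are added at the root of the archive (arcname=os.path.sep).
    tar_path = sagemaker.utils.create_tar_file(
        source_files=['/tmp/requirements.txt'],
        dir_files=['/tmp/model-dir'],
        target='/tmp/model.tar.gz')

The local-mode and repack_model call sites in this patch rely on the dir_files form to keep a whole directory's layout at the root of model.tar.gz.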
diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 882d71d45d..a05f7dfac2 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -234,11 +234,10 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): elif container_dir == '/opt/ml/output': sagemaker.local.utils.recursive_copy(host_dir, output_artifacts) - # Tar Artifacts -> model.tar.gz and output.tar.gz - model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)] - output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)] - sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz')) - sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz')) + sagemaker.utils.create_tar_file(dir_files=[model_artifacts], + target=os.path.join(compressed_artifacts, 'model.tar.gz')) + sagemaker.utils.create_tar_file(dir_files=[output_artifacts], + target=os.path.join(compressed_artifacts, 'output.tar.gz')) if output_data_config['S3OutputPath'] == '': output_data = 'file://%s' % compressed_artifacts diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index b71feba59b..5b631d5cc5 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -260,11 +260,14 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): obj.download_file(file_path) -def create_tar_file(source_files, target=None): - """Create a tar file containing all the source_files +def create_tar_file(source_files=None, target=None, dir_files=None): + """Create a tar file containing all the source_files and the content of all dir_files Args: source_files (List[str]): List of file paths that will be contained in the tar file + target (str): target path of the tar file + dir_files (List[str]): List of directories which will have their contents copy into + the tar file Returns: (str): path to created tar file @@ -275,10 +278,17 @@ def create_tar_file(source_files, target=None): else: _, filename = tempfile.mkstemp() + dir_files = dir_files or [] + source_files = source_files or [] + with tarfile.open(filename, mode='w:gz') as t: for sf in source_files: # Add all files from the directory into the root of the directory structure of the tar t.add(sf, arcname=os.path.basename(sf)) + + for dir_file in dir_files: + t.add(dir_file, arcname=os.path.sep) + return filename @@ -323,13 +333,11 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp() with _tmpdir() as tmp: - tmp_model_dir = os.path.join(tmp, 'model') os.mkdir(tmp_model_dir) model_from_s3 = model_uri.startswith('s3://') if model_from_s3: - local_model_path = os.path.join(tmp, 'tar_file') download_file_from_url(model_uri, local_model_path, sagemaker_session) @@ -349,10 +357,7 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio shutil.copytree(dirname, code_dir) - files_to_compress = [os.path.join(tmp_model_dir, file) - for file in os.listdir(tmp_model_dir)] - - tar_file = sagemaker.utils.create_tar_file(files_to_compress, new_model_path) + tar_file = sagemaker.utils.create_tar_file(dir_files=[tmp_model_dir], target=new_model_path) if model_from_s3: url = parse.urlparse(model_uri) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index e4736702c7..c915f0af4f 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -51,8 +51,8 @@ def 
tfs_predictor(instance_type, sagemaker_session, tf_full_version): def tar_dir(directory, tmpdir): - source_files = [os.path.join(directory, name) for name in os.listdir(directory)] - return sagemaker.utils.create_tar_file(source_files, os.path.join(str(tmpdir), 'model.tar.gz')) + return sagemaker.utils.create_tar_file(dir_files=[directory], + target=os.path.join(str(tmpdir), 'model.tar.gz')) @pytest.fixture(scope='module') diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 469a7c5fdf..3dc13c4f81 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -268,23 +268,57 @@ def test_download_file(): @patch('tarfile.open') def test_create_tar_file_with_provided_path(open): - open.return_value = open - open.__enter__ = Mock() - open.__exit__ = Mock(return_value=None) + files = mock_tarfile(open) + file_list = ['/tmp/a', '/tmp/b'] + path = sagemaker.utils.create_tar_file(file_list, target='/my/custom/path.tar.gz') assert path == '/my/custom/path.tar.gz' + assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] @patch('tarfile.open') -@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path'))) -def test_create_tar_file_with_auto_generated_path(open): +def test_create_tar_file_with_directories(open): + files = mock_tarfile(open) + + path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'], + target='/my/custom/path.tar.gz') + assert path == '/my/custom/path.tar.gz' + assert files == [['/tmp/a', '/'], ['/tmp/b', '/']] + + +@patch('tarfile.open') +def test_create_tar_file_with_files_and_directories(open): + files = mock_tarfile(open) + + path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'], + source_files=['/tmp/c', '/tmp/d'], + target='/my/custom/path.tar.gz') + assert path == '/my/custom/path.tar.gz' + assert files == [['/tmp/c', 'c'], ['/tmp/d', 'd'], ['/tmp/a', '/'], ['/tmp/b', '/']] + + +def mock_tarfile(open): open.return_value = open + files = [] + + def add_files(filename, arcname): + files.append([filename, arcname]) + open.__enter__ = Mock() + open.__enter__().add = add_files open.__exit__ = Mock(return_value=None) - file_list = ['/tmp/a', '/tmp/b'] - path = sagemaker.utils.create_tar_file(file_list) + return files + + +@patch('tarfile.open') +@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path'))) +def test_create_tar_file_with_auto_generated_path(open): + files = mock_tarfile(open) + + path = sagemaker.utils.create_tar_file(['/tmp/a', '/tmp/b']) assert path == '/auto/generated/path' + assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] def write_file(path, content): From abd0d7719b55e1baddb43452f506977bcf3a64c3 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 14 May 2019 13:35:49 -0700 Subject: [PATCH 08/12] Fix integ test --- tests/integ/test_tfs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index c915f0af4f..e62adcd3a2 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -55,7 +55,6 @@ def tar_dir(directory, tmpdir): target=os.path.join(str(tmpdir), 'model.tar.gz')) -@pytest.fixture(scope='module') def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, sagemaker_session, tf_full_version, From 1a96e0e0241b60743ca464b8d79c7126aedb4da5 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 14 May 2019 15:14:03 -0700 Subject: [PATCH 09/12] Fix integ test --- tests/integ/test_tfs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 
e62adcd3a2..e538b3a1d6 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -55,6 +55,7 @@ def tar_dir(directory, tmpdir): target=os.path.join(str(tmpdir), 'model.tar.gz')) +@pytest.fixture def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, sagemaker_session, tf_full_version, From 82d53661a89b3567fd0c9c4b91148da2d581e80b Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 15 May 2019 10:59:31 -0700 Subject: [PATCH 10/12] Fix PR comments --- src/sagemaker/fw_utils.py | 21 +++++++++++---------- src/sagemaker/local/image.py | 13 +++++++++---- src/sagemaker/utils.py | 22 +++++----------------- tests/integ/test_tfs.py | 8 ++++++-- tests/unit/test_utils.py | 21 --------------------- 5 files changed, 31 insertions(+), 54 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0dbb36e17c..daf0b60877 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -182,16 +182,9 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, tmp = tempfile.mkdtemp() try: - if directory: - source_files = dependencies - dir_files = [directory] - else: - source_files = [script] + dependencies - dir_files = [] - - tar_file = sagemaker.utils.create_tar_file(source_files=source_files, - dir_files=dir_files, - target=os.path.join(tmp, _TAR_SOURCE_FILENAME)) + source_files = _list_files_to_compress(script, directory) + dependencies + tar_file = sagemaker.utils.create_tar_file(source_files, + os.path.join(tmp, _TAR_SOURCE_FILENAME)) if kms_key: extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key} @@ -205,6 +198,14 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name) +def _list_files_to_compress(script, directory): + if directory is None: + return [script] + + basedir = directory if directory else os.path.dirname(script) + return [os.path.join(basedir, name) for name in os.listdir(basedir)] + + def framework_name_from_image(image_name): """Extract the framework and Python version from the image name. 
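A minimal sketch of the approach this revision settles on instead: create_tar_file keeps its original file-list signature, and callers that need a whole directory at the archive root (repack_model and the integration-test helper) use tarfile directly. Paths below are hypothetical:

    import os
    import tarfile

    def tar_dir_at_root(directory, target):
        # Pack the contents of `directory` at the root of the archive; this is the layout
        # repack_model produces, e.g. members '/model' and '/code/inference.py'.
        with tarfile.open(target, mode='w:gz') as t:
            t.add(directory, arcname=os.path.sep)
        return target

    tar_dir_at_root('/tmp/repacked-model-dir', '/tmp/model.tar.gz')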
diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index a05f7dfac2..979725b2f3 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -234,10 +234,15 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): elif container_dir == '/opt/ml/output': sagemaker.local.utils.recursive_copy(host_dir, output_artifacts) - sagemaker.utils.create_tar_file(dir_files=[model_artifacts], - target=os.path.join(compressed_artifacts, 'model.tar.gz')) - sagemaker.utils.create_tar_file(dir_files=[output_artifacts], - target=os.path.join(compressed_artifacts, 'output.tar.gz')) + # Tar Artifacts -> model.tar.gz and output.tar.gz + model_files = [os.path.join(model_artifacts, name) for name in + os.listdir(model_artifacts)] + output_files = [os.path.join(output_artifacts, name) for name in + os.listdir(output_artifacts)] + sagemaker.utils.create_tar_file(model_files, + os.path.join(compressed_artifacts, 'model.tar.gz')) + sagemaker.utils.create_tar_file(output_files, + os.path.join(compressed_artifacts, 'output.tar.gz')) if output_data_config['S3OutputPath'] == '': output_data = 'file://%s' % compressed_artifacts diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5b631d5cc5..2c19864b53 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -260,35 +260,22 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): obj.download_file(file_path) -def create_tar_file(source_files=None, target=None, dir_files=None): - """Create a tar file containing all the source_files and the content of all dir_files - +def create_tar_file(source_files, target=None): + """Create a tar file containing all the source_files Args: source_files (List[str]): List of file paths that will be contained in the tar file - target (str): target path of the tar file - dir_files (List[str]): List of directories which will have their contents copy into - the tar file - Returns: (str): path to created tar file - """ if target: filename = target else: _, filename = tempfile.mkstemp() - dir_files = dir_files or [] - source_files = source_files or [] - with tarfile.open(filename, mode='w:gz') as t: for sf in source_files: # Add all files from the directory into the root of the directory structure of the tar t.add(sf, arcname=os.path.basename(sf)) - - for dir_file in dir_files: - t.add(dir_file, arcname=os.path.sep) - return filename @@ -357,14 +344,15 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio shutil.copytree(dirname, code_dir) - tar_file = sagemaker.utils.create_tar_file(dir_files=[tmp_model_dir], target=new_model_path) + with tarfile.open(new_model_path, mode='w:gz') as t: + t.add(tmp_model_dir, arcname=os.path.sep) if model_from_s3: url = parse.urlparse(model_uri) bucket, key = url.netloc, url.path.lstrip('/') new_key = key.replace(os.path.basename(key), new_model_name) - sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(tar_file) + sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path) return 's3://%s/%s' % (bucket, new_key) else: return 'file://%s' % new_model_path diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index e538b3a1d6..05e0725d5c 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -12,6 +12,8 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import tarfile + import botocore.exceptions import os @@ -50,9 +52,11 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version): def tar_dir(directory, tmpdir): + target = os.path.join(str(tmpdir), 'model.tar.gz') - return sagemaker.utils.create_tar_file(dir_files=[directory], - target=os.path.join(str(tmpdir), 'model.tar.gz')) + with tarfile.open(target, mode='w:gz') as t: + t.add(directory, arcname=os.path.sep) + return target @pytest.fixture diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 3dc13c4f81..efd0ad499a 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -277,27 +277,6 @@ def test_create_tar_file_with_provided_path(open): assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] -@patch('tarfile.open') -def test_create_tar_file_with_directories(open): - files = mock_tarfile(open) - - path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'], - target='/my/custom/path.tar.gz') - assert path == '/my/custom/path.tar.gz' - assert files == [['/tmp/a', '/'], ['/tmp/b', '/']] - - -@patch('tarfile.open') -def test_create_tar_file_with_files_and_directories(open): - files = mock_tarfile(open) - - path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'], - source_files=['/tmp/c', '/tmp/d'], - target='/my/custom/path.tar.gz') - assert path == '/my/custom/path.tar.gz' - assert files == [['/tmp/c', 'c'], ['/tmp/d', 'd'], ['/tmp/a', '/'], ['/tmp/b', '/']] - - def mock_tarfile(open): open.return_value = open files = [] From 539285467e2cfce84d00d2545d7c7e2a0fddcfb8 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 15 May 2019 13:56:18 -0700 Subject: [PATCH 11/12] Fix PR comments --- src/sagemaker/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 2c19864b53..5f35e4f259 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -358,14 +358,6 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio return 'file://%s' % new_model_path -def _list_files(script, directory): - if directory is None: - return [script] - - basedir = directory if directory else os.path.dirname(script) - return [os.path.join(basedir, name) for name in os.listdir(basedir)] - - def download_file_from_url(url, dst, sagemaker_session): url = parse.urlparse(url) bucket, key = url.netloc, url.path.lstrip('/') From 09bd18584f2b5e076547eca4ddbd92a75440d6de Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 15 May 2019 15:40:46 -0700 Subject: [PATCH 12/12] Fix unit tests --- tests/unit/test_algorithm.py | 178 ++++++++++++++++++++--------------- 1 file changed, 100 insertions(+), 78 deletions(-) diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index 62ddf05f74..c00254eae0 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -154,7 +154,8 @@ } -def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_supported_input_mode_with_valid_input_types(session): # verify that the Estimator verifies the # input mode that an Algorithm supports. 
@@ -178,7 +179,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) # Creating a File mode Estimator with a File mode algorithm should work AlgorithmEstimator( @@ -186,7 +187,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -209,7 +210,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) # Creating a Pipe mode Estimator with a Pipe mode algorithm should work. AlgorithmEstimator( @@ -218,7 +219,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session train_instance_type='ml.m4.xlarge', train_instance_count=1, input_mode='Pipe', - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) any_input_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -241,7 +242,7 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=any_input_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=any_input_algo) # Creating a File mode Estimator with an algorithm that supports both input modes # should work. @@ -250,11 +251,12 @@ def test_algorithm_supported_input_mode_with_valid_input_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_supported_input_mode_with_bad_input_types(session): # verify that the Estimator verifies raises exceptions when # attempting to train with an incorrect input type @@ -278,7 +280,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=file_mode_algo) # Creating a Pipe mode Estimator with a File mode algorithm should fail. with pytest.raises(ValueError): @@ -288,7 +290,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): train_instance_type='ml.m4.xlarge', train_instance_count=1, input_mode='Pipe', - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) @@ -311,7 +313,7 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=pipe_mode_algo) # Creating a File mode Estimator with a Pipe mode algorithm should fail. 
with pytest.raises(ValueError): @@ -320,12 +322,13 @@ def test_algorithm_supported_input_mode_with_bad_input_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_trainining_channels_with_expected_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels['TrainingSpecification']['TrainingChannels'] = [ @@ -347,14 +350,14 @@ def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session) }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) + session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Pass training and validation channels. This should work @@ -365,7 +368,8 @@ def test_algorithm_trainining_channels_with_expected_channels(sagemaker_session) @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_trainining_channels_with_invalid_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels['TrainingSpecification']['TrainingChannels'] = [ @@ -387,14 +391,14 @@ def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): }, ] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) + session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Passing only validation should fail as training is required. 
@@ -406,7 +410,8 @@ def test_algorithm_trainining_channels_with_invalid_channels(sagemaker_session): estimator.fit({'training': 's3://some/data', 'training2': 's3://some/other/data'}) -def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_train_instance_types_valid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] @@ -414,7 +419,7 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): 'SupportedTrainingInstanceTypes' ] = train_instance_types - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=describe_algo_response ) @@ -423,7 +428,7 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) AlgorithmEstimator( @@ -431,11 +436,12 @@ def test_algorithm_train_instance_types_valid_instance_types(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m5.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_train_instance_types_invalid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] @@ -443,7 +449,7 @@ def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session 'SupportedTrainingInstanceTypes' ] = train_instance_types - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=describe_algo_response ) @@ -454,18 +460,19 @@ def test_algorithm_train_instance_types_invalid_instance_types(sagemaker_session role='SageMakerRole', train_instance_type='ml.m4.8xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_distributed_training_validation(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_distributed_training_validation(session): distributed_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) distributed_algo['TrainingSpecification']['SupportsDistributedTraining'] = True single_instance_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) single_instance_algo['TrainingSpecification']['SupportsDistributedTraining'] = False - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=distributed_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=distributed_algo) # Distributed training should work for Distributed and Single instance. 
AlgorithmEstimator( @@ -473,7 +480,7 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) AlgorithmEstimator( @@ -481,10 +488,10 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=2, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=single_instance_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=single_instance_algo) # distributed training on a single instance algorithm should fail. with pytest.raises(ValueError): @@ -493,11 +500,12 @@ def test_algorithm_distributed_training_validation(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m5.2xlarge', train_instance_count=2, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) -def test_algorithm_hyperparameter_integer_range_valid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_integer_range_valid_range(session): hyperparameters = [ { 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', @@ -515,21 +523,22 @@ def test_algorithm_hyperparameter_integer_range_valid_range(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(max_leaf_nodes=1) estimator.set_hyperparameters(max_leaf_nodes=100000) -def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_integer_range_invalid_range(session): hyperparameters = [ { 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', @@ -547,14 +556,14 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session) some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(ValueError): @@ -564,7 +573,8 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(sagemaker_session) estimator.set_hyperparameters(max_leaf_nodes=100001) -def test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_continuous_range_valid_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -582,14 +592,14 @@ def 
test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(max_leaf_nodes=0) @@ -598,7 +608,8 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(sagemaker_session estimator.set_hyperparameters(max_leaf_nodes=1) -def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_continuous_range_invalid_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -616,14 +627,14 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_sessi some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(ValueError): @@ -633,7 +644,8 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(sagemaker_sessi estimator.set_hyperparameters(max_leaf_nodes=-0.1) -def test_algorithm_hyperparameter_categorical_range(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_hyperparameter_categorical_range(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -649,14 +661,14 @@ def test_algorithm_hyperparameter_categorical_range(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) estimator.set_hyperparameters(hp1='MXNet') @@ -669,7 +681,8 @@ def test_algorithm_hyperparameter_categorical_range(sagemaker_session): estimator.set_hyperparameters(hp1='MxNET') -def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_required_hyperparameters_not_provided(session): hyperparameters = [ { 'Description': 'A continuous hyperparameter', @@ -691,14 +704,14 @@ def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - 
sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # hp1 is required and was not provided @@ -711,8 +724,9 @@ def test_algorithm_required_hyperparameters_not_provided(sagemaker_session): estimator.fit({'training': 's3://some/place'}) +@patch('sagemaker.Session') @patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -def test_algorithm_required_hyperparameters_are_provided(sagemaker_session): +def test_algorithm_required_hyperparameters_are_provided(session): hyperparameters = [ { 'Description': 'A categorical hyperparameter', @@ -741,21 +755,22 @@ def test_algorithm_required_hyperparameters_are_provided(sagemaker_session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # All 3 Hyperparameters are provided estimator.set_hyperparameters(hp1='TF', hp2='TF2', free_text_hp1='Hello!') -def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_required_free_text_hyperparameter_not_provided(session): hyperparameters = [ { 'Name': 'free_text_hp1', @@ -776,14 +791,14 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_sess some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) # Calling fit with unset required hyperparameters should fail @@ -796,9 +811,10 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(sagemaker_sess estimator.set_hyperparameters(free_text_hp2='some text') +@patch('sagemaker.Session') @patch('sagemaker.algorithm.AlgorithmEstimator.create_model') -def test_algorithm_create_transformer(create_model, sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +def test_algorithm_create_transformer(create_model, session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -806,10 +822,10 @@ def test_algorithm_create_transformer(create_model, sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - estimator.latest_training_job = _TrainingJob(sagemaker_session, 
'some-job-name') + estimator.latest_training_job = _TrainingJob(session, 'some-job-name') model = Mock() model.name = 'my-model' create_model.return_value = model @@ -821,8 +837,9 @@ def test_algorithm_create_transformer(create_model, sagemaker_session): assert transformer.model_name == 'my-model' -def test_algorithm_create_transformer_without_completed_training_job(sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +@patch('sagemaker.Session') +def test_algorithm_create_transformer_without_completed_training_job(session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -830,7 +847,7 @@ def test_algorithm_create_transformer_without_completed_training_job(sagemaker_s role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) with pytest.raises(RuntimeError) as error: @@ -839,10 +856,11 @@ +@patch('sagemaker.Session') @patch('sagemaker.algorithm.AlgorithmEstimator.create_model') -def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_session): +def test_algorithm_create_transformer_with_product_id(create_model, session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['ProductId'] = 'some-product-id' - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -850,10 +868,10 @@ def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_se role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) - estimator.latest_training_job = _TrainingJob(sagemaker_session, 'some-job-name') + estimator.latest_training_job = _TrainingJob(session, 'some-job-name') model = Mock() model.name = 'my-model' create_model.return_value = model @@ -862,8 +880,9 @@ def test_algorithm_create_transformer_with_product_id(create_model, sagemaker_se assert transformer.env is None -def test_algorithm_enable_network_isolation_no_product_id(sagemaker_session): - sagemaker_session.sagemaker_client.describe_algorithm = Mock( +@patch('sagemaker.Session') +def test_algorithm_enable_network_isolation_no_product_id(session): + session.sagemaker_client.describe_algorithm = Mock( return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -871,17 +890,18 @@ def test_algorithm_enable_network_isolation_no_product_id(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) network_isolation = estimator.enable_network_isolation() assert network_isolation is False -def test_algorithm_enable_network_isolation_with_product_id(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_enable_network_isolation_with_product_id(session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['ProductId'] = 'some-product-id' - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -889,17 +909,18 @@ def test_algorithm_enable_network_isolation_with_product_id(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge',
train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, ) network_isolation = estimator.enable_network_isolation() assert network_isolation is True -def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_encrypt_inter_container_traffic(session): response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response['encrypt_inter_container_traffic'] = True - sagemaker_session.sagemaker_client.describe_algorithm = Mock( + session.sagemaker_client.describe_algorithm = Mock( return_value=response) estimator = AlgorithmEstimator( @@ -907,7 +928,7 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, encrypt_inter_container_traffic=True ) @@ -915,11 +936,12 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session): assert encrypt_inter_container_traffic is True -def test_algorithm_no_required_hyperparameters(sagemaker_session): +@patch('sagemaker.Session') +def test_algorithm_no_required_hyperparameters(session): some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) del some_algo['TrainingSpecification']['SupportedHyperParameters'] - sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) + session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) # Calling AlgorithmEstimator() with unset required hyperparameters # should fail if they are required. @@ -929,5 +951,5 @@ def test_algorithm_no_required_hyperparameters(sagemaker_session): role='SageMakerRole', train_instance_type='ml.m4.2xlarge', train_instance_count=1, - sagemaker_session=sagemaker_session, + sagemaker_session=session, )
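
Note on usage: the net effect of the tar-handling changes in this series is easier to see in a short sketch than in the hunks alone. The snippet below is illustrative only and not part of the patch; the paths are hypothetical. It exercises the narrowed create_tar_file(source_files, target=None) signature and the direct tarfile pattern that repack_model and the tar_dir test helper now use in place of the removed dir_files argument.

    # Illustrative sketch, not part of the patch; paths are hypothetical.
    import os
    import tarfile

    import sagemaker.utils

    model_dir = '/tmp/model'              # hypothetical directory of artifacts
    archive = '/tmp/out/model.tar.gz'     # hypothetical output path

    # create_tar_file now takes explicit file paths; each file lands at the
    # root of the archive under its basename.
    files = [os.path.join(model_dir, name) for name in os.listdir(model_dir)]
    sagemaker.utils.create_tar_file(files, archive)

    # To place an entire directory tree at the archive root (as repack_model
    # and the integ test helper now do), tarfile is called directly:
    repacked = '/tmp/out/repacked.tar.gz'  # hypothetical output path
    with tarfile.open(repacked, mode='w:gz') as t:
        t.add(model_dir, arcname=os.path.sep)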