From 6fed597c0ddc2eb96fbf2a0e922d9f68c8035e8a Mon Sep 17 00:00:00 2001
From: Marcio Dos Santos
Date: Tue, 28 May 2019 11:55:46 -0700
Subject: [PATCH 01/11] feature: repack_model support dependencies

---
 src/sagemaker/tensorflow/serving.py |  16 +-
 src/sagemaker/utils.py              | 107 +++++++----
 tests/unit/test_mxnet.py            |  41 ++++
 tests/unit/test_utils.py            | 278 ++++++++++++++--------------
 4 files changed, 260 insertions(+), 182 deletions(-)

diff --git a/src/sagemaker/tensorflow/serving.py b/src/sagemaker/tensorflow/serving.py
index a680f2df30..7a37318d10 100644
--- a/src/sagemaker/tensorflow/serving.py
+++ b/src/sagemaker/tensorflow/serving.py
@@ -13,6 +13,7 @@
 from __future__ import absolute_import

 import logging
+import os

 import sagemaker
 from sagemaker.content_types import CONTENT_TYPE_JSON
@@ -128,10 +129,17 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
         env = self._get_container_env()

         if self.entry_point:
-            model_data = sagemaker.utils.repack_model(self.entry_point,
-                                                      self.source_dir,
-                                                      self.model_data,
-                                                      self.sagemaker_session)
+            key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image)
+
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz')
+
+            sagemaker.utils.repack_model(self.entry_point,
+                                         self.source_dir,
+                                         self.dependencies,
+                                         self.model_data,
+                                         model_data,
+                                         self.sagemaker_session)
         else:
             model_data = self.model_data

diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index d20f3194e0..ff81c9c35d 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,8 +29,6 @@

 import six

-import sagemaker
-
 ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'


@@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'):
     shutil.rmtree(tmp)


-def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
+def repack_model(inference_script,
+                 source_directory,
+                 dependencies,
+                 model_uri,
+                 repacked_model_uri,
+                 sagemaker_session):
     """Unpacks a model tarball and creates a new model tarball with the provided code script.

     This function does the following:
     - uncompresses model tarball from S3
     - replaces the inference code from the model with the new code provided
     - compresses the new model tarball and saves it in S3 or file system

     Args:
         inference_script (str): path or basename of the inference script that will be packed into the model
         source_directory (str): path including all the files that will be packed into the model
+        dependencies (list[str]): a list of paths to directories (absolute or relative) with
+            any additional libraries that will be exported to the container (default: []).
+            The library folders will be copied to SageMaker in the same folder where the entry point is
+            copied. Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>>     |------ train.py
+                >>>     |------ common
+                >>>     |------ virtual-env
+
         model_uri (str): S3 or file system location of the original model tar
+        repacked_model_uri (str): path or file system location where the new model will be saved
         sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.
-    Returns: str: path to the new packed model
     """
-    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
+    dependencies = dependencies or []

     with _tmpdir() as tmp:
-        tmp_model_dir = os.path.join(tmp, 'model')
-        os.mkdir(tmp_model_dir)
+        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

-        model_from_s3 = model_uri.lower().startswith('s3://')
-        if model_from_s3:
-            local_model_path = os.path.join(tmp, 'tar_file')
-            download_file_from_url(model_uri, local_model_path, sagemaker_session)
+        _update_code(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp)

-            new_model_path = os.path.join(tmp, new_model_name)
-        else:
-            local_model_path = model_uri.replace('file://', '')
-            new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
+        tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz')
+        with tarfile.open(tmp_model_path, mode='w:gz') as t:
+            t.add(model_dir, arcname=os.path.sep)

-        with tarfile.open(name=local_model_path, mode='r:gz') as t:
-            t.extractall(path=tmp_model_dir)
+        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session)

-        code_dir = os.path.join(tmp_model_dir, 'code')
-        if os.path.exists(code_dir):
-            shutil.rmtree(code_dir, ignore_errors=True)

-        if source_directory and source_directory.lower().startswith('s3://'):
-            local_code_path = os.path.join(tmp, 'local_code.tar.gz')
-            download_file_from_url(source_directory, local_code_path, sagemaker_session)
+def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
+    if repacked_model_uri.lower().startswith('s3://'):
+        url = parse.urlparse(repacked_model_uri)
+        bucket, key = url.netloc, url.path.lstrip('/')
+        new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
+
+        sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(
+            tmp_model_path)
+    else:
+        shutil.move(tmp_model_path, repacked_model_uri.replace('file://', ''))
+
+
+def _update_code(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp):
+    code_dir = os.path.join(model_dir, 'code')
+    if os.path.exists(code_dir):
+        shutil.rmtree(code_dir, ignore_errors=True)
+    if source_directory and source_directory.lower().startswith('s3://'):
+        local_code_path = os.path.join(tmp, 'local_code.tar.gz')
+        download_file_from_url(source_directory, local_code_path, sagemaker_session)

            with tarfile.open(name=local_code_path, mode='r:gz') as t:
                t.extractall(path=code_dir)

-        elif source_directory:
-            shutil.copytree(source_directory, code_dir)
-        else:
-            os.mkdir(code_dir)
-            shutil.copy2(inference_script, code_dir)
+    elif source_directory:
+        shutil.copytree(source_directory, code_dir)
+    else:
+        os.mkdir(code_dir)
+        shutil.copy2(inference_script, code_dir)

-        with tarfile.open(new_model_path, mode='w:gz') as t:
-            t.add(tmp_model_dir, arcname=os.path.sep)
+    for dependency in dependencies:
+        if os.path.isdir(dependency):
+            shutil.copytree(dependency, code_dir)
+        else:
+            shutil.copy2(dependency, code_dir)

-        if model_from_s3:
-            url = parse.urlparse(model_uri)
-            bucket, key = url.netloc, url.path.lstrip('/')
-            new_key = key.replace(os.path.basename(key), new_model_name)

-            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path)
-            return 's3://%s/%s' % (bucket, new_key)
-        else:
-            return 'file://%s' % new_model_path
+def _extract_model(model_uri, sagemaker_session, tmp):
+    tmp_model_dir = os.path.join(tmp, 'model')
+    os.mkdir(tmp_model_dir)
+    if model_uri.lower().startswith('s3://'):
+ local_model_path = os.path.join(tmp, 'tar_file') + download_file_from_url(model_uri, local_model_path, sagemaker_session) + else: + local_model_path = model_uri.replace('file://', '') + with tarfile.open(name=local_model_path, mode='r:gz') as t: + t.extractall(path=tmp_model_dir) + return tmp_model_dir def download_file_from_url(url, dst, sagemaker_session): diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index dfb298d47b..e736bf0b64 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -280,6 +280,47 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): assert isinstance(predictor, MXNetPredictor) +@patch('sagemaker.utils.repack_model') +@patch('time.strftime', return_value=TIMESTAMP) +def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version): + mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + framework_version=mxnet_version) + + inputs = 's3://mybucket/train' + + mx.fit(inputs=inputs) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ['train', 'logs_for_job'] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ['resource'] + + expected_train_args = _create_train_job(mxnet_version) + expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + + actual_train_args = sagemaker_session.method_calls[0][2] + assert actual_train_args == expected_train_args + + model = mx.create_model() + + expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, 'gpu') + environment = { + 'Environment': { + 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz', + 'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', + 'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20' + }, + 'Image': expected_image_base.format(mxnet_version), + 'ModelDataUrl': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz' + } + assert environment == model.prepare_container_def(GPU) + + assert 'cpu' in model.prepare_container_def(CPU)['Image'] + predictor = mx.deploy(1, GPU) + assert isinstance(predictor, MXNetPredictor) + + @patch('sagemaker.utils.repack_model', return_value=REPACKED_MODEL_DATA) @patch('time.strftime', return_value=TIMESTAMP) def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 94511939e3..2f2f706ff0 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -21,11 +21,13 @@ import re import time +from boto3 import exceptions import pytest from mock import call, patch, Mock, MagicMock import sagemaker +BUCKET_WITHOUT_WRITING_PERMISSION = 's3://bucket-without-writing-permission' NAME = 'base_name' BUCKET_NAME = 'some_bucket' @@ -300,207 +302,201 @@ def test_create_tar_file_with_auto_generated_path(open): assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] -def write_file(path, content): - with open(path, 'a') as f: - f.write(content) +def create_file_tree(root, tree): + for file in tree: + try: + os.makedirs(os.path.join(root, os.path.dirname(file))) + except: # noqa: E722 Using bare except because p2/3 incompatibility issues. 
+ pass + with open(os.path.join(root, file), 'a') as f: + f.write(file) -def test_repack_model_without_source_dir(tmpdir): +@pytest.fixture() +def tmp(tmpdir): + yield str(tmpdir) - tmp = str(tmpdir) - model_path = os.path.join(tmp, 'model') - write_file(model_path, 'model data') +def test_repack_model_without_source_dir(tmp, fake_s3): - source_dir = os.path.join(tmp, 'source-dir') - os.mkdir(source_dir) - script_path = os.path.join(source_dir, 'inference.py') - write_file(script_path, 'inference script') + create_file_tree(tmp, ['model-dir/model', + 'dependencies/a', + 'dependencies/b', + 'source-dir/inference.py', + 'source-dir/this-file-should-not-be-included.py']) - script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py') - write_file(script_path, 'This file should not be included') + fake_s3.tar_and_upload('model-dir', 's3://fake/location') - contents = [model_path] + sagemaker.utils.repack_model(inference_script=os.path.join(tmp, 'source-dir/inference.py'), + source_directory=None, + dependencies=[os.path.join(tmp, 'dependencies/a'), + os.path.join(tmp, 'dependencies/b')], + model_uri='s3://fake/location', + repacked_model_uri='s3://destination-bucket/model.tar.gz', + sagemaker_session=fake_s3.sagemaker_session) - sagemaker_session = MagicMock() - mock_s3_model_tar(contents, sagemaker_session, tmp) - fake_upload_path = mock_s3_upload(sagemaker_session, tmp) - - model_uri = 's3://fake/location' - - new_model_uri = sagemaker.utils.repack_model(os.path.join(source_dir, 'inference.py'), - None, - model_uri, - sagemaker_session) - - assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} - assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) - - -def test_repack_model_with_entry_point_without_path_without_source_dir(tmpdir): + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/model', '/code/a', + '/code/b', '/code/inference.py'} - tmp = str(tmpdir) - model_path = os.path.join(tmp, 'model') - write_file(model_path, 'model data') +def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3): - source_dir = os.path.join(tmp, 'source-dir') - os.mkdir(source_dir) - script_path = os.path.join(source_dir, 'inference.py') - write_file(script_path, 'inference script') + create_file_tree(tmp, ['model-dir/model', + 'source-dir/inference.py', + 'source-dir/this-file-should-not-be-included.py']) - script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py') - write_file(script_path, 'This file should not be included') - - contents = [model_path] - - sagemaker_session = MagicMock() - mock_s3_model_tar(contents, sagemaker_session, tmp) - fake_upload_path = mock_s3_upload(sagemaker_session, tmp) - - model_uri = 's3://fake/location' + fake_s3.tar_and_upload('model-dir', 's3://fake/location') cwd = os.getcwd() try: - os.chdir(source_dir) - - new_model_uri = sagemaker.utils.repack_model('inference.py', - None, - model_uri, - sagemaker_session) + os.chdir(os.path.join(tmp, 'source-dir')) + + sagemaker.utils.repack_model('inference.py', + None, + None, + 's3://fake/location', + 's3://destination-bucket/model.tar.gz', + fake_s3.sagemaker_session) finally: os.chdir(cwd) - assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} - assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) - + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/inference.py', '/model'} -def test_repack_model_from_s3_saved_model_to_s3(tmpdir): - tmp = str(tmpdir) +def 
test_repack_model_from_s3_to_s3(tmp, fake_s3): - model_path = os.path.join(tmp, 'model') - write_file(model_path, 'model data') + create_file_tree(tmp, ['model-dir/model', + 'source-dir/inference.py', + 'source-dir/this-file-should-be-included.py']) - source_dir = os.path.join(tmp, 'source-dir') - os.mkdir(source_dir) - script_path = os.path.join(source_dir, 'inference.py') - write_file(script_path, 'inference script') + fake_s3.tar_and_upload('model-dir', 's3://fake/location') - script_path = os.path.join(source_dir, 'this-file-should-be-included.py') - write_file(script_path, 'This file should be included') + sagemaker.utils.repack_model('inference.py', + os.path.join(tmp, 'source-dir'), + None, + 's3://fake/location', + 's3://destination-bucket/model.tar.gz', + fake_s3.sagemaker_session) - contents = [model_path] + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/this-file-should-be-included.py', + '/code/inference.py', + '/model'} - sagemaker_session = MagicMock() - mock_s3_model_tar(contents, sagemaker_session, tmp) - fake_upload_path = mock_s3_upload(sagemaker_session, tmp) - - model_uri = 's3://fake/location' - new_model_uri = sagemaker.utils.repack_model('inference.py', - source_dir, - model_uri, - sagemaker_session) +def test_repack_model_from_file_to_file(tmp): + create_file_tree(tmp, ['model', + 'dependencies/a', + 'source-dir/inference.py']) - assert list_tar_files(fake_upload_path, tmpdir) == {'/code/this-file-should-be-included.py', - '/code/inference.py', - '/model'} - assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path) + sagemaker_session = MagicMock() -def test_repack_model_from_file_saves_model_to_file(tmpdir): + file_mode_path = 'file://%s' % model_tar_path + destination_path = 'file://%s' % os.path.join(tmp, 'repacked-model.tar.gz') - tmp = str(tmpdir) + sagemaker.utils.repack_model('inference.py', + os.path.join(tmp, 'source-dir'), + [os.path.join(tmp, 'dependencies/a')], + file_mode_path, + destination_path, + sagemaker_session) - model_path = os.path.join(tmp, 'model') - write_file(model_path, 'model data') + assert list_tar_files(destination_path, tmp) == {'/code/a', '/code/inference.py', '/model'} - source_dir = os.path.join(tmp, 'source-dir') - os.mkdir(source_dir) - script_path = os.path.join(source_dir, 'inference.py') - write_file(script_path, 'inference script') - model_tar_path = os.path.join(tmp, 'model.tar.gz') - sagemaker.utils.create_tar_file([model_path], model_tar_path) +def test_repack_model_with_inference_code_should_replace_the_code(tmp, fake_s3): + create_file_tree(tmp, ['model-dir/model', + 'source-dir/new-inference.py', + 'model-dir/code/old-inference.py']) - sagemaker_session = MagicMock() + fake_s3.tar_and_upload('model-dir', 's3://fake/location') - file_mode_path = 'file://%s' % model_tar_path - new_model_uri = sagemaker.utils.repack_model('inference.py', - source_dir, - file_mode_path, - sagemaker_session) + sagemaker.utils.repack_model('inference.py', + os.path.join(tmp, 'source-dir'), + None, + 's3://fake/location', + 's3://destination-bucket/repacked-model', + fake_s3.sagemaker_session) - assert os.path.dirname(new_model_uri) == os.path.dirname(file_mode_path) - assert list_tar_files(new_model_uri, tmpdir) == {'/code/inference.py', '/model'} + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/new-inference.py', '/model'} -def 
test_repack_model_with_inference_code_should_replace_the_code(tmpdir): +def test_repack_model_from_file_to_folder(tmp): + create_file_tree(tmp, ['model', + 'source-dir/inference.py']) - tmp = str(tmpdir) + model_tar_path = os.path.join(tmp, 'model.tar.gz') + sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path) - model_path = os.path.join(tmp, 'model') - write_file(model_path, 'model data') + file_mode_path = 'file://%s' % model_tar_path - source_dir = os.path.join(tmp, 'source-dir') - os.mkdir(source_dir) - script_path = os.path.join(source_dir, 'new-inference.py') - write_file(script_path, 'inference script') + sagemaker.utils.repack_model('inference.py', + os.path.join(tmp, 'source-dir'), + [], + file_mode_path, + 'file://%s/repacked-model.tar.gz' % tmp, + MagicMock()) - old_code_path = os.path.join(tmp, 'code') - os.mkdir(old_code_path) - old_script_path = os.path.join(old_code_path, 'old-inference.py') - write_file(old_script_path, 'old inference script') - contents = [model_path, old_code_path] + assert list_tar_files('file://%s/repacked-model.tar.gz' % tmp, tmp) == {'/code/inference.py', '/model'} - sagemaker_session = MagicMock() - mock_s3_model_tar(contents, sagemaker_session, tmp) - fake_upload_path = mock_s3_upload(sagemaker_session, tmp) - model_uri = 's3://fake/location' +class FakeS3(object): - new_model_uri = sagemaker.utils.repack_model('inference.py', - source_dir, - model_uri, - sagemaker_session) + def __init__(self, tmp): + self.tmp = tmp + self.sagemaker_session = MagicMock() + self.location_map = {} + self.current_bucket = None - assert list_tar_files(fake_upload_path, tmpdir) == {'/code/new-inference.py', '/model'} - assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = self.download_file + self.sagemaker_session.boto_session.resource().Bucket.side_effect = self.bucket + self.fake_upload_path = self.mock_s3_upload() + def bucket(self, name): + self.current_bucket = name + return self -def mock_s3_model_tar(contents, sagemaker_session, tmp): - model_tar_path = os.path.join(tmp, 'model.tar.gz') - sagemaker.utils.create_tar_file(contents, model_tar_path) - mock_s3_download(sagemaker_session, model_tar_path) + def download_file(self, path, target): + key = '%s/%s' % (self.current_bucket, path) + shutil.copy2(self.location_map[key], target) + def tar_and_upload(self, path, fake_location): + tar_location = os.path.join(self.tmp, 'model-%s.tar.gz' % time.time()) + with tarfile.open(tar_location, mode='w:gz') as t: + t.add(os.path.join(self.tmp, path), arcname=os.path.sep) -def mock_s3_download(sagemaker_session, model_tar_path): - def download_file(_, target): - shutil.copy2(model_tar_path, target) + self.location_map[fake_location.replace('s3://', '')] = tar_location + return tar_location - sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = download_file + def mock_s3_upload(self): + dst = os.path.join(self.tmp, 'dst') + class MockS3Object(object): -def mock_s3_upload(sagemaker_session, tmp): - dst = os.path.join(tmp, 'dst') + def __init__(self, bucket, key): + self.bucket = bucket + self.key = key - class MockS3Object(object): + def upload_file(self, target): + if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION: + raise exceptions.S3UploadFailedError() + shutil.copy2(target, dst) - def __init__(self, bucket, key): - self.bucket = bucket - self.key = key + self.sagemaker_session.boto_session.resource().Object = 
MockS3Object + return dst - def upload_file(self, target): - shutil.copy2(target, dst) - sagemaker_session.boto_session.resource().Object = MockS3Object - return dst +@pytest.fixture() +def fake_s3(tmp): + return FakeS3(tmp) -def list_tar_files(tar_ball, tmpdir): +def list_tar_files(tar_ball, tmp): tar_ball = tar_ball.replace('file://', '') - startpath = str(tmpdir.ensure('tmp', dir=True)) + startpath = os.path.join(tmp, 'startpath') + os.mkdir(startpath) with tarfile.open(name=tar_ball, mode='r:gz') as t: t.extractall(path=startpath) From 286128533bc1f021baa69511b8e50fd010f9baf6 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 28 May 2019 12:03:02 -0700 Subject: [PATCH 02/11] feature: repack_model support dependencies --- src/sagemaker/model.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index bbb78bdfc8..6cb625e51e 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -435,21 +435,25 @@ def _upload_code(self, key_prefix, repack=False): local_code = utils.get_config_value('local.local_code', self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: self.uploaded_code = None - else: - if not repack: - bucket = self.bucket or self.sagemaker_session.default_bucket() - self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session, - bucket=bucket, - s3_key_prefix=key_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies) + elif not repack: + bucket = self.bucket or self.sagemaker_session.default_bucket() + self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session, + bucket=bucket, + s3_key_prefix=key_prefix, + script=self.entry_point, + directory=self.source_dir, + dependencies=self.dependencies) if repack: - self.repacked_model_data = utils.repack_model(inference_script=self.entry_point, - source_directory=self.source_dir, - model_uri=self.model_data, - sagemaker_session=self.sagemaker_session) + bucket = self.bucket or self.sagemaker_session.default_bucket() + self.repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') + + utils.repack_model(inference_script=self.entry_point, + source_directory=self.source_dir, + dependencies=self.dependencies, + model_uri=self.model_data, + repacked_model_uri=self.repacked_model_data, + sagemaker_session=self.sagemaker_session) self.uploaded_code = UploadedCode(s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)) From 53fa83a29baed20896568804c7c034c049c6323b Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 28 May 2019 12:05:40 -0700 Subject: [PATCH 03/11] feature: repack_model support dependencies --- tests/unit/test_mxnet.py | 51 +++++----------------------------------- 1 file changed, 6 insertions(+), 45 deletions(-) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index e736bf0b64..0f45afe51f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -29,7 +29,6 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py') MODEL_DATA = 's3://mybucket/model' -REPACKED_MODEL_DATA = 's3://mybucket/repacked/model' TIMESTAMP = '2017-11-06-14:14:15.672' TIME = 1507167947 BUCKET_NAME = 'mybucket' @@ -321,46 +320,6 @@ def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_vers assert isinstance(predictor, MXNetPredictor) 
-@patch('sagemaker.utils.repack_model', return_value=REPACKED_MODEL_DATA) -@patch('time.strftime', return_value=TIMESTAMP) -def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version) - - inputs = 's3://mybucket/train' - - mx.fit(inputs=inputs) - - sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job'] - boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ['resource'] - - expected_train_args = _create_train_job(mxnet_version) - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs - - actual_train_args = sagemaker_session.method_calls[0][2] - assert actual_train_args == expected_train_args - - model = mx.create_model() - - expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, 'gpu') - environment = { - 'Environment': { - 'SAGEMAKER_SUBMIT_DIRECTORY': REPACKED_MODEL_DATA, - 'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20' - }, - 'Image': expected_image_base.format(mxnet_version), 'ModelDataUrl': REPACKED_MODEL_DATA - } - assert environment == model.prepare_container_def(GPU) - - assert 'cpu' in model.prepare_container_def(CPU)['Image'] - predictor = mx.deploy(1, GPU) - assert isinstance(predictor, MXNetPredictor) - - @patch('sagemaker.utils.create_tar_file', MagicMock()) @patch('time.strftime', return_value=TIMESTAMP) def test_mxnet_neo(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): @@ -407,21 +366,23 @@ def test_model(sagemaker_session): assert isinstance(predictor, MXNetPredictor) -@patch('sagemaker.utils.repack_model', return_value=REPACKED_MODEL_DATA) +@patch('sagemaker.utils.repack_model') def test_model_mms_version(repack_model, sagemaker_session): model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, framework_version=MXNetModel._LOWEST_MMS_VERSION, - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, name='test-mxnet-model') predictor = model.deploy(1, GPU) repack_model.assert_called_once_with(inference_script=SCRIPT_PATH, source_directory=None, + dependencies=[], model_uri=MODEL_DATA, + repacked_model_uri='s3://mybucket/test-mxnet-model/model.tar.gz', sagemaker_session=sagemaker_session) assert model.model_data == MODEL_DATA - assert model.repacked_model_data == REPACKED_MODEL_DATA - assert model.uploaded_code == UploadedCode(s3_prefix=REPACKED_MODEL_DATA, + assert model.repacked_model_data == 's3://mybucket/test-mxnet-model/model.tar.gz' + assert model.uploaded_code == UploadedCode(s3_prefix='s3://mybucket/test-mxnet-model/model.tar.gz', script_name=os.path.basename(SCRIPT_PATH)) assert isinstance(predictor, MXNetPredictor) From aa4056faa32b7cd0f1a10d2b8975d8d5c667dfaf Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 28 May 2019 12:12:03 -0700 Subject: [PATCH 04/11] feature: repack_model support dependencies --- src/sagemaker/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index ff81c9c35d..35a405cf62 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -371,8 +371,8 @@ def 
_update_code(model_dir, inference_script, source_directory, dependencies, sa local_code_path = os.path.join(tmp, 'local_code.tar.gz') download_file_from_url(source_directory, local_code_path, sagemaker_session) - with tarfile.open(name=local_code_path, mode='r:gz') as t: - t.extractall(path=code_dir) + with tarfile.open(name=local_code_path, mode='r:gz') as t: + t.extractall(path=code_dir) elif source_directory: shutil.copytree(source_directory, code_dir) From 0c4face33ba95bc96bd787f018d24cc03e62e16a Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 28 May 2019 16:18:31 -0700 Subject: [PATCH 05/11] feature: repack_model support dependencies --- src/sagemaker/model.py | 4 +++- src/sagemaker/utils.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 6cb625e51e..7c39f1fda2 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -446,7 +446,7 @@ def _upload_code(self, key_prefix, repack=False): if repack: bucket = self.bucket or self.sagemaker_session.default_bucket() - self.repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') + repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') utils.repack_model(inference_script=self.entry_point, source_directory=self.source_dir, @@ -454,6 +454,8 @@ def _upload_code(self, key_prefix, repack=False): model_uri=self.model_data, repacked_model_uri=self.repacked_model_data, sagemaker_session=self.sagemaker_session) + + self.repacked_model_data = repacked_model_data self.uploaded_code = UploadedCode(s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 35a405cf62..9d1d139cb3 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -342,7 +342,7 @@ def repack_model(inference_script, with _tmpdir() as tmp: model_dir = _extract_model(model_uri, sagemaker_session, tmp) - _update_code(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp) + _create_or_update_code_dir(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp) tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz') with tarfile.open(tmp_model_path, mode='w:gz') as t: @@ -363,7 +363,8 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session): shutil.move(tmp_model_path, repacked_model_uri.replace('file://', '')) -def _update_code(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp): +def _create_or_update_code_dir(model_dir, inference_script, source_directory, + dependencies, sagemaker_session, tmp): code_dir = os.path.join(model_dir, 'code') if os.path.exists(code_dir): shutil.rmtree(code_dir, ignore_errors=True) From f2713d7e0bedca1fa70010af73de63ba25c5766d Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Tue, 28 May 2019 16:22:27 -0700 Subject: [PATCH 06/11] feature: repack_model support dependencies --- src/sagemaker/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 802875f888..3d518aa299 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -465,7 +465,7 @@ def _upload_code(self, key_prefix, repack=False): source_directory=self.source_dir, dependencies=self.dependencies, model_uri=self.model_data, - repacked_model_uri=self.repacked_model_data, + repacked_model_uri=repacked_model_data, sagemaker_session=self.sagemaker_session) 
self.repacked_model_data = repacked_model_data From 6ed1f5a1929f25c83b730726b1d9d1e865ec41f7 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 29 May 2019 10:07:25 -0700 Subject: [PATCH 07/11] feature: repack_model support dependencies --- src/sagemaker/mxnet/model.py | 6 +- .../code/inference.py | 1 + tests/integ/test_tfs.py | 4 ++ tests/unit/test_tfs.py | 56 +++++++++++++++++++ 4 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 40eea52bf5..29f5040c5d 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -92,21 +92,21 @@ def prepare_container_def(self, instance_type, accelerator_type=None): Returns: dict[str, str]: A container definition object usable with the CreateModel API. """ - mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION) + is_mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION) deploy_image = self.image if not deploy_image: region_name = self.sagemaker_session.boto_session.region_name framework_name = self.__framework_name__ - if mms_version: + if is_mms_version: framework_name += '-serving' deploy_image = create_image_uri(region_name, framework_name, instance_type, self.framework_version, self.py_version, accelerator_type=accelerator_type) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) - self._upload_code(deploy_key_prefix, mms_version) + self._upload_code(deploy_key_prefix, is_mms_version) deploy_env = dict(self.env) deploy_env.update(self._framework_env_vars()) diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py index 507d0c44f3..d371cc7a16 100644 --- a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. 
import json +import dependency def input_handler(data, context): data = json.loads(data.read().decode('utf-8')) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 05e0725d5c..e189ba7452 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -97,9 +97,13 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type, sagemaker_session): entry_point = os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference/code/inference.py') + dependencies = [os.path.join(tests.integ.DATA_DIR, + 'tfs/tfs-test-model-with-inference/dependency.py')] + model = Model(entry_point=entry_point, model_data=model_data, role='SageMakerRole', + dependencies= dependencies, framework_version=tf_full_version, sagemaker_session=sagemaker_session) predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) diff --git a/tests/unit/test_tfs.py b/tests/unit/test_tfs.py index 5bcdbfba8b..d2d59e0c2d 100644 --- a/tests/unit/test_tfs.py +++ b/tests/unit/test_tfs.py @@ -15,6 +15,8 @@ import io import json import logging + +import mock import pytest from mock import Mock from sagemaker.tensorflow import TensorFlow @@ -102,6 +104,60 @@ def test_tfs_model_with_custom_image(sagemaker_session, tf_version): assert cdef['Image'] == 'my-image' +@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') +@mock.patch('sagemaker.utils.repack_model') +def test_tfs_model_with_entry_point(repack_model, model_code_key_prefix, sagemaker_session, + tf_version): + model = Model("s3://some/data.tar.gz", + entry_point='train.py', + role=ROLE, framework_version=tf_version, + image='my-image', sagemaker_session=sagemaker_session) + + model.prepare_container_def(INSTANCE_TYPE) + + model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) + + repack_model.assert_called_with('train.py', None, [], 's3://some/data.tar.gz', + 's3://my_bucket/key-prefix/model.tar.gz', + sagemaker_session) + + +@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') +@mock.patch('sagemaker.utils.repack_model') +def test_tfs_model_with_source(repack_model, model_code_key_prefix, sagemaker_session, tf_version): + model = Model("s3://some/data.tar.gz", + entry_point='train.py', + source_dir='src', + role=ROLE, framework_version=tf_version, + image='my-image', sagemaker_session=sagemaker_session) + + model.prepare_container_def(INSTANCE_TYPE) + + model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) + + repack_model.assert_called_with('train.py', 'src', [], 's3://some/data.tar.gz', + 's3://my_bucket/key-prefix/model.tar.gz', + sagemaker_session) + + +@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') +@mock.patch('sagemaker.utils.repack_model') +def test_tfs_model_with_dependencies(repack_model, model_code_key_prefix, sagemaker_session, tf_version): + model = Model("s3://some/data.tar.gz", + entry_point='train.py', + dependencies=['src', 'lib'], + role=ROLE, framework_version=tf_version, + image='my-image', sagemaker_session=sagemaker_session) + + model.prepare_container_def(INSTANCE_TYPE) + + model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) + + repack_model.assert_called_with('train.py', None, ['src', 'lib'], 's3://some/data.tar.gz', + 's3://my_bucket/key-prefix/model.tar.gz', + sagemaker_session) + + def test_estimator_deploy(sagemaker_session): container_log_level = '"logging.INFO"' source_dir = 's3://mybucket/source' From 
0e83796b22ccb605f494239d21429b28b74aed55 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 29 May 2019 10:07:47 -0700 Subject: [PATCH 08/11] feature: repack_model support dependencies --- tests/data/tfs/tfs-test-model-with-inference/dependency.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/data/tfs/tfs-test-model-with-inference/dependency.py diff --git a/tests/data/tfs/tfs-test-model-with-inference/dependency.py b/tests/data/tfs/tfs-test-model-with-inference/dependency.py new file mode 100644 index 0000000000..e69de29bb2 From 89b72a5d5344e383388ae6c9782918302696cf8c Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 29 May 2019 10:11:44 -0700 Subject: [PATCH 09/11] feature: repack_model support dependencies --- tests/integ/test_tfs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index e189ba7452..4a70084be6 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -103,7 +103,7 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type, model = Model(entry_point=entry_point, model_data=model_data, role='SageMakerRole', - dependencies= dependencies, + dependencies=dependencies, framework_version=tf_full_version, sagemaker_session=sagemaker_session) predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) From 797bfd601dadb5eaf21b0b4685b8985c1e53a816 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Wed, 29 May 2019 14:31:41 -0700 Subject: [PATCH 10/11] feature: repack_model support dependencies --- .../tfs-test-model-with-inference/code/inference.py | 2 +- .../tfs/tfs-test-model-with-inference/dependency.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py index d371cc7a16..2fe2eb3327 100644 --- a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -1,4 +1,4 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of diff --git a/tests/data/tfs/tfs-test-model-with-inference/dependency.py b/tests/data/tfs/tfs-test-model-with-inference/dependency.py index e69de29bb2..c60b935b80 100644 --- a/tests/data/tfs/tfs-test-model-with-inference/dependency.py +++ b/tests/data/tfs/tfs-test-model-with-inference/dependency.py @@ -0,0 +1,12 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
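At this point in the series the repack_model contract is settled, and it can be exercised locally because file:// URIs bypass S3 on both the read and write paths. A minimal sketch, assuming the patched sagemaker.utils from this branch is importable; every file name below is hypothetical:

    import os
    import tarfile
    import tempfile
    from unittest.mock import MagicMock  # the unit tests above use the external `mock` package instead

    import sagemaker.utils

    tmp = tempfile.mkdtemp()

    # Hypothetical inputs: a trivial model artifact, an entry point and one dependency file.
    os.makedirs(os.path.join(tmp, 'source-dir'))
    for name in ('model', os.path.join('source-dir', 'inference.py'), 'dependency.py'):
        open(os.path.join(tmp, name), 'w').close()

    model_tar = os.path.join(tmp, 'model.tar.gz')
    sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar)

    repacked = 'file://%s' % os.path.join(tmp, 'repacked.tar.gz')
    sagemaker.utils.repack_model('inference.py',
                                 os.path.join(tmp, 'source-dir'),
                                 [os.path.join(tmp, 'dependency.py')],
                                 'file://%s' % model_tar,
                                 repacked,
                                 MagicMock())  # the session is only consulted for s3:// URIs

    with tarfile.open(repacked.replace('file://', ''), 'r:gz') as t:
        print(sorted(t.getnames()))  # expected to include /model, /code/inference.py and /code/dependency.py

This mirrors test_repack_model_from_file_to_file above: the original model artifact survives the repack, while code/ is rebuilt from the entry point plus the listed dependencies.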
From 79efec2c690b4dc8b3ab50cef28a2ece90ad4579 Mon Sep 17 00:00:00 2001 From: Marcio Dos Santos Date: Thu, 30 May 2019 08:35:49 -0700 Subject: [PATCH 11/11] feature: repack_model support dependencies --- .../dependency.py | 0 .../inference.py | 27 +++++++++++++++++++ .../code/inference.py | 2 -- tests/integ/test_tfs.py | 14 +++++----- 4 files changed, 34 insertions(+), 9 deletions(-) rename tests/data/tfs/{tfs-test-model-with-inference => tfs-test-entrypoint-and-dependencies}/dependency.py (100%) create mode 100644 tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py diff --git a/tests/data/tfs/tfs-test-model-with-inference/dependency.py b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py similarity index 100% rename from tests/data/tfs/tfs-test-model-with-inference/dependency.py rename to tests/data/tfs/tfs-test-entrypoint-and-dependencies/dependency.py diff --git a/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py new file mode 100644 index 0000000000..2fe2eb3327 --- /dev/null +++ b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py @@ -0,0 +1,27 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import json + +import dependency + +def input_handler(data, context): + data = json.loads(data.read().decode('utf-8')) + new_values = [x + 1 for x in data['instances']] + dumps = json.dumps({'instances': new_values}) + return dumps + + +def output_handler(data, context): + response_content_type = context.accept_header + prediction = data.content + return prediction, response_content_type diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py index 2fe2eb3327..2f691fea1d 100644 --- a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -12,8 +12,6 @@ # language governing permissions and limitations under the License. 
import json -import dependency - def input_handler(data, context): data = json.loads(data.read().decode('utf-8')) new_values = [x + 1 for x in data['instances']] diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 4a70084be6..ab43b1368c 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -84,8 +84,8 @@ def tfs_predictor_with_model_and_entry_point_same_tar(instance_type, @pytest.fixture(scope='module') -def tfs_predictor_with_model_and_entry_point_separated(instance_type, - sagemaker_session, tf_full_version): +def tfs_predictor_with_model_and_entry_point_and_dependencies(instance_type, + sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') model_data = sagemaker_session.upload_data( @@ -96,9 +96,9 @@ def tfs_predictor_with_model_and_entry_point_separated(instance_type, with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): entry_point = os.path.join(tests.integ.DATA_DIR, - 'tfs/tfs-test-model-with-inference/code/inference.py') + 'tfs/tfs-test-entrypoint-and-dependencies/inference.py') dependencies = [os.path.join(tests.integ.DATA_DIR, - 'tfs/tfs-test-model-with-inference/dependency.py')] + 'tfs/tfs-test-entrypoint-and-dependencies/dependency.py')] model = Model(entry_point=entry_point, model_data=model_data, @@ -156,12 +156,12 @@ def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_ assert expected_result == result -def test_predict_with_model_and_entry_point_separated( - tfs_predictor_with_model_and_entry_point_separated): +def test_predict_with_model_and_entry_point_and_dependencies_separated( + tfs_predictor_with_model_and_entry_point_and_dependencies): input_data = {'instances': [1.0, 2.0, 5.0]} expected_result = {'predictions': [4.0, 4.5, 6.0]} - result = tfs_predictor_with_model_and_entry_point_separated.predict(input_data) + result = tfs_predictor_with_model_and_entry_point_and_dependencies.predict(input_data) assert expected_result == result
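Taken end to end, callers consume this feature through the Model classes rather than through repack_model directly. A usage sketch of the TensorFlow Serving path, where the bucket, role, framework version and file names are placeholder assumptions rather than values taken from the patches:

    from sagemaker.tensorflow.serving import Model

    # 'inference.py' imports 'dependency', so both must land in code/ inside the repacked tarball.
    model = Model(model_data='s3://my-bucket/model/model.tar.gz',
                  role='SageMakerRole',
                  entry_point='inference.py',
                  dependencies=['dependency.py'],
                  framework_version='1.12')

    # prepare_container_def(), invoked during deploy(), repacks model_data so the
    # container sees code/inference.py next to code/dependency.py.
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')
    print(predictor.predict({'instances': [1.0, 2.0, 5.0]}))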