diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fada8fd8ef..a7757cad0f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,8 @@ CHANGELOG ========== * doc-fix: Remove incorrect parameter for EI TFS Python README +* feature: ``Predictor``: delete SageMaker model +* feature: ``Pipeline``: delete SageMaker model 1.18.3.post1 ============ diff --git a/README.rst b/README.rst index a7ca3d748b..682fc60a4f 100644 --- a/README.rst +++ b/README.rst @@ -192,6 +192,8 @@ Here is an end to end example of how to use a SageMaker Estimator: # Tears down the SageMaker endpoint and endpoint configuration mxnet_predictor.delete_endpoint() + # Deletes the SageMaker model + mxnet_predictor.delete_model() The example above will eventually delete both the SageMaker endpoint and endpoint configuration through `delete_endpoint()`. If you want to keep your SageMaker endpoint configuration, use the value False for the `delete_endpoint_config` parameter, as shown below. @@ -230,6 +232,9 @@ For more `information `__ , and use e # Tears down the endpoint container and deletes the corresponding endpoint configuration mxnet_predictor.delete_endpoint() + # Deletes the model + mxnet_predictor.delete_model() + If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor. The correct session is generated when you call ``model.deploy()``. @@ -307,6 +315,9 @@ Here is an end-to-end example: # Tear down the endpoint container and delete the corresponding endpoint configuration predictor.delete_endpoint() + # Deletes the model + predictor.delete_model() + If you don't want to deploy your model locally, you can also choose to perform a Local Batch Transform Job. This is useful if you want to test your container before creating a Sagemaker Batch Transform Job. Note that the performance diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index af379a6b3d..4d9d3cd19e 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -103,3 +103,14 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags) if self.predictor_cls: return self.predictor_cls(self.endpoint_name, self.sagemaker_session) + + def delete_model(self): + """Delete the SageMaker model backing this pipeline model. This does not delete the list of SageMaker models used + in multiple containers to build the inference pipeline. + + """ + + if self.name is None: + raise ValueError('The SageMaker model must be created before attempting to delete.') + + self.sagemaker_session.delete_model(self.name) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 958b21afe6..5da69dfb2f 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -56,6 +56,8 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ self.deserializer = deserializer self.content_type = content_type or getattr(serializer, 'content_type', None) self.accept = accept or getattr(deserializer, 'accept', None) + self._endpoint_config_name = self._get_endpoint_config_name() + self._model_names = self._get_model_names() def predict(self, data, initial_args=None): """Return the inference from the specified endpoint. @@ -109,16 +111,16 @@ def _delete_endpoint_config(self): """Delete the Amazon SageMaker endpoint configuration """ - endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint) - endpoint_config_name = endpoint_description['EndpointConfigName'] - self.sagemaker_session.delete_endpoint_config(endpoint_config_name) + self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name) def delete_endpoint(self, delete_endpoint_config=True): - """Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor. + """Delete the Amazon SageMaker endpoint backing this predictor. Also delete the endpoint configuration attached + to it if delete_endpoint_config is True. Args: - delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint - configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True) + delete_endpoint_config (bool, optional): Flag to indicate whether to delete endpoint configuration together + with endpoint. Defaults to True. If True, both endpoint and endpoint configuration will be deleted. If + False, only endpoint will be deleted. """ if delete_endpoint_config: @@ -126,6 +128,34 @@ def delete_endpoint(self, delete_endpoint_config=True): self.sagemaker_session.delete_endpoint(self.endpoint) + def delete_model(self): + """Deletes the Amazon SageMaker models backing this predictor. + + """ + request_failed = False + failed_models = [] + for model_name in self._model_names: + try: + self.sagemaker_session.delete_model(model_name) + except Exception: # pylint: disable=broad-except + request_failed = True + failed_models.append(model_name) + + if request_failed: + raise Exception('One or more models cannot be deleted, please retry. \n' + 'Failed models: {}'.format(', '.join(failed_models))) + + def _get_endpoint_config_name(self): + endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint) + endpoint_config_name = endpoint_desc['EndpointConfigName'] + return endpoint_config_name + + def _get_model_names(self): + endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config( + EndpointConfigName=self._endpoint_config_name) + production_variants = endpoint_config['ProductionVariants'] + return map(lambda d: d['ModelName'], production_variants) + class _CsvSerializer(object): def __init__(self): diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index d3c597b13e..1fbe45c618 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -92,3 +92,8 @@ def test_inference_pipeline_model_deploy(sagemaker_session): invalid_data = "1.0,28.0,C,38.0,71.5,1.0" assert (predictor.predict(invalid_data) is None) + + model.delete_model() + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) + assert 'Could not find model' in str(exception.value) diff --git a/tests/integ/test_kmeans.py b/tests/integ/test_kmeans.py index 85f6b247e8..37e18fbfd7 100644 --- a/tests/integ/test_kmeans.py +++ b/tests/integ/test_kmeans.py @@ -75,6 +75,11 @@ def test_kmeans(sagemaker_session): assert record.label["closest_cluster"] is not None assert record.label["distance_to_cluster"] is not None + predictor.delete_model() + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) + assert 'Could not find model' in str(exception.value) + def test_async_kmeans(sagemaker_session): training_job_name = "" diff --git a/tests/integ/test_mxnet_train.py b/tests/integ/test_mxnet_train.py index 3d67f3b4c2..ed57789933 100644 --- a/tests/integ/test_mxnet_train.py +++ b/tests/integ/test_mxnet_train.py @@ -72,6 +72,11 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version) data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) + predictor.delete_model() + with pytest.raises(Exception) as exception: + sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) + assert 'Could not find model' in str(exception.value) + def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version): endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp()) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index fbc439ecde..aa712038da 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -45,6 +45,15 @@ GPU = 'ml.p2.xlarge' CPU = 'ml.c4.xlarge' +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -54,6 +63,8 @@ def sagemaker_session(): describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 04afe65b5c..59d803254e 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -102,6 +102,15 @@ 'ModelDataUrl': MODEL_DATA, } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + class DummyFramework(Framework): __framework_name__ = 'dummy' @@ -146,6 +155,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 9dbb7c014c..d420301378 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -37,6 +37,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -47,6 +56,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index e64adbd1be..402383348b 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -39,6 +39,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -49,6 +58,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index a5b829fa60..dc87a38ec1 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -36,6 +36,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -46,6 +55,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 890f6ec9fc..fad5ae64b1 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -40,6 +40,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -50,6 +59,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 5c0528e322..eed9902292 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -35,6 +35,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -44,6 +53,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index 8819717957..1e5fa64f0f 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -37,6 +37,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -47,6 +56,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 0136053c13..18302fdb20 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -45,6 +45,15 @@ CPU_C5 = 'ml.c5.xlarge' LAUNCH_PS_DISTRIBUTIONS_DICT = {'parameter_server': {'enabled': True}} +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -55,6 +64,8 @@ def sagemaker_session(): describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} describe_compilation = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/model_c5.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.wait_for_compilation_job = Mock(return_value=describe_compilation) session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 1f3866968b..c72e8cea70 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -36,6 +36,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -46,6 +55,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 4fc6130722..8d8e557014 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -45,6 +45,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -55,6 +64,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index c72cd4fedd..0748d90ca6 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -36,6 +36,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -46,6 +55,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index e45c34409e..d640aa3235 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -138,3 +138,22 @@ def test_deploy_tags(tfo, time, sagemaker_session): 'InitialInstanceCount': 1, 'VariantName': 'AllTraffic'}], tags) + + +def test_delete_model_without_deploy(sagemaker_session): + pipeline_model = PipelineModel([], role=ROLE, sagemaker_session=sagemaker_session) + + expected_error_message = 'The SageMaker model must be created before attempting to delete.' + with pytest.raises(ValueError, match=expected_error_message): + pipeline_model.delete_model() + + +@patch('tarfile.open') +@patch('time.strftime', return_value=TIMESTAMP) +def test_delete_model(tfo, time, sagemaker_session): + framework_model = DummyFrameworkModel(sagemaker_session) + pipeline_model = PipelineModel([framework_model], role=ROLE, sagemaker_session=sagemaker_session) + pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) + + pipeline_model.delete_model() + sagemaker_session.delete_model.assert_called_with(pipeline_model.name) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 55ac1f5f9a..2b1434c580 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -16,7 +16,7 @@ import json import os import pytest -from mock import Mock +from mock import Mock, call import numpy as np @@ -319,11 +319,22 @@ def test_numpy_deser_from_npy_object_array(): RETURN_VALUE = 0 CSV_RETURN_VALUE = "1,2,3\r\n" +ENDPOINT_DESC = { + 'EndpointConfigName': ENDPOINT +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + def empty_sagemaker_session(): ims = Mock(name='sagemaker_session') ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) response_body = Mock('body') response_body.read = Mock('read', return_value=RETURN_VALUE) @@ -378,6 +389,11 @@ def json_sagemaker_session(): ims = Mock(name='sagemaker_session') ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) response_body = Mock('body') response_body.read = Mock('read', return_value=json.dumps([RETURN_VALUE])) @@ -416,6 +432,11 @@ def ret_csv_sagemaker_session(): ims = Mock(name='sagemaker_session') ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) response_body = Mock('body') response_body.read = Mock('read', return_value=CSV_RETURN_VALUE) @@ -466,3 +487,27 @@ def test_delete_endpoint_only(): sagemaker_session.delete_endpoint.assert_called_with(ENDPOINT) sagemaker_session.delete_endpoint_config.assert_not_called() + + +def test_delete_model(): + sagemaker_session = empty_sagemaker_session() + predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session) + + predictor.delete_model() + + expected_call_count = 2 + expected_call_args_list = [call('model-1'), call('model-2')] + assert sagemaker_session.delete_model.call_count == expected_call_count + assert sagemaker_session.delete_model.call_args_list == expected_call_args_list + + +def test_delete_model_fail(): + sagemaker_session = empty_sagemaker_session() + sagemaker_session.sagemaker_client.delete_model = Mock(side_effect=Exception('Could not find model.')) + expected_error_message = 'One or more models cannot be deleted, please retry.' + + predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session) + + with pytest.raises(Exception) as exception: + predictor.delete_model() + assert expected_error_message in str(exception.val) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 20edf62640..3a71d320b8 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -43,6 +43,15 @@ GPU = 'ml.p2.xlarge' CPU = 'ml.c4.xlarge' +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture(name='sagemaker_session') def fixture_sagemaker_session(): @@ -52,6 +61,8 @@ def fixture_sagemaker_session(): describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index d265d22ba6..561f7c29ac 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -38,6 +38,15 @@ } } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -48,6 +57,8 @@ def sagemaker_session(): sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index c9e28f3c23..a4657c901d 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -40,6 +40,15 @@ GPU = 'ml.p2.xlarge' CPU = 'ml.c4.xlarge' +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture(name='sagemaker_session') def fixture_sagemaker_session(): @@ -49,6 +58,8 @@ def fixture_sagemaker_session(): describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 5795aa3e90..9a3b199156 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -43,6 +43,15 @@ REGION = 'us-west-2' CPU = 'ml.c4.xlarge' +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -52,6 +61,8 @@ def sagemaker_session(): describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index bbd588170d..bacf64a81f 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -26,6 +26,15 @@ BUCKET_NAME = 'Some-Bucket' ENDPOINT = 'some-endpoint' +ENDPOINT_DESC = { + 'EndpointConfigName': ENDPOINT +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -33,6 +42,8 @@ def sagemaker_session(): sms = Mock(name='sagemaker_session', boto_session=boto_mock, region_name=REGION, config=None, local_mode=False) sms.boto_region_name = REGION + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index d6482ae479..e2777d0326 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -48,6 +48,15 @@ DISTRIBUTION_ENABLED = {'parameter_server': {'enabled': True}} DISTRIBUTION_MPI_ENABLED = {'mpi': {'enabled': True, 'custom_mpi_options': 'options', 'processes_per_host': 2}} +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -58,6 +67,8 @@ def sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return session diff --git a/tests/unit/test_tf_predictor.py b/tests/unit/test_tf_predictor.py index a9ff684df6..b58971adbf 100644 --- a/tests/unit/test_tf_predictor.py +++ b/tests/unit/test_tf_predictor.py @@ -43,12 +43,23 @@ JSON_CONTENT_TYPE = "application/json" PROTO_CONTENT_TYPE = "application/octet-stream" +ENDPOINT_DESC = { + 'EndpointConfigName': ENDPOINT +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name='boto_session', region_name=REGION) ims = Mock(name='sagemaker_session', boto_session=boto_mock) ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return ims diff --git a/tests/unit/test_tfs.py b/tests/unit/test_tfs.py index e29e3d4098..77f68561a2 100644 --- a/tests/unit/test_tfs.py +++ b/tests/unit/test_tfs.py @@ -40,6 +40,15 @@ } REGRESS_RESPONSE = {'results': [3.5, 4.0]} +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -50,6 +59,8 @@ def sagemaker_session(): session.expand_role = Mock(name="expand_role", return_value=ROLE) describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return session @@ -191,14 +202,14 @@ def test_predictor_regress(sagemaker_session): assert REGRESS_RESPONSE == result -def test_predictor_regress_bad_content_type(): +def test_predictor_regress_bad_content_type(sagemaker_session): predictor = Predictor('endpoint', sagemaker_session, csv_serializer) with pytest.raises(ValueError): predictor.regress(REGRESS_INPUT) -def test_predictor_classify_bad_content_type(): +def test_predictor_classify_bad_content_type(sagemaker_session): predictor = Predictor('endpoint', sagemaker_session, csv_serializer) with pytest.raises(ValueError): diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 76480d9cf5..4fec84965e 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -127,6 +127,15 @@ 'HyperParameterTuningJobArn': 'arn:tuning_job', } +ENDPOINT_DESC = { + 'EndpointConfigName': 'test-endpoint' +} + +ENDPOINT_CONFIG_DESC = { + 'ProductionVariants': [{'ModelName': 'model-1'}, + {'ModelName': 'model-2'}] +} + @pytest.fixture() def sagemaker_session(): @@ -135,6 +144,10 @@ def sagemaker_session(): sms.boto_region_name = REGION sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) sms.config = None + + sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) + sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) + return sms