diff --git a/setup.py b/setup.py index 5b6c31fd3c..3c4728c96e 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def read_version(): # Declare minimal set for installation required_packages = [ "attrs", - "boto3>=1.20.18", + "boto3>=1.20.21", "google-pasta", "numpy>=1.9.0", "protobuf>=3.1", diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index cf039fa010..1de03d6183 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -852,8 +852,8 @@ def logs(self): def deploy( self, - initial_instance_count, - instance_type, + initial_instance_count=None, + instance_type=None, serializer=None, deserializer=None, accelerator_type=None, @@ -864,6 +864,7 @@ def deploy( kms_key=None, data_capture_config=None, tags=None, + serverless_inference_config=None, **kwargs, ): """Deploy the trained model to an Amazon SageMaker endpoint. @@ -874,10 +875,14 @@ def deploy( http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html Args: - initial_instance_count (int): Minimum number of EC2 instances to - deploy to an endpoint for prediction. - instance_type (str): Type of EC2 instance to deploy to an endpoint - for prediction, for example, 'ml.c4.xlarge'. + initial_instance_count (int): The initial number of instances to run + in the ``Endpoint`` created from this ``Model``. If not using + serverless inference, then it need to be a number larger or equals + to 1 (default: None) + instance_type (str): The EC2 instance type to deploy this Model to. + For example, 'ml.p2.xlarge', or 'local' for local mode. If not using + serverless inference, then it is required to deploy a model. + (default: None) serializer (:class:`~sagemaker.serializers.BaseSerializer`): A serializer object, used to encode data for an inference endpoint (default: None). If ``serializer`` is not None, then @@ -910,6 +915,11 @@ def deploy( data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. + serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): + Specifies configuration related to serverless endpoint. Use this configuration + when trying to create serverless endpoint and make serverless inference. If + empty object passed through, we will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None) tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -927,6 +937,7 @@ def deploy( endpoint and obtain inferences. """ removed_kwargs("update_endpoint", kwargs) + is_serverless = serverless_inference_config is not None self._ensure_latest_training_job() self._ensure_base_job_name() default_name = name_from_base(self.base_job_name) @@ -934,7 +945,7 @@ def deploy( model_name = model_name or default_name self.deploy_instance_type = instance_type - if use_compiled_model: + if use_compiled_model and not is_serverless: family = "_".join(instance_type.split(".")[:-1]) if family not in self._compiled_models: raise ValueError( @@ -959,6 +970,7 @@ def deploy( wait=wait, kms_key=kms_key, data_capture_config=data_capture_config, + serverless_inference_config=serverless_inference_config, ) def register( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 830bb50dab..504f0f1e73 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -32,6 +32,7 @@ from sagemaker.inputs import CompilationInput from sagemaker.deprecations import removed_kwargs from sagemaker.predictor import PredictorBase +from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer LOGGER = logging.getLogger("sagemaker") @@ -209,7 +210,7 @@ def register( model_package_arn=model_package.get("ModelPackageArn"), ) - def _init_sagemaker_session_if_does_not_exist(self, instance_type): + def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already. The type of session object is determined by the instance type. @@ -688,8 +689,8 @@ def compile( def deploy( self, - initial_instance_count, - instance_type, + initial_instance_count=None, + instance_type=None, serializer=None, deserializer=None, accelerator_type=None, @@ -698,6 +699,7 @@ def deploy( kms_key=None, wait=True, data_capture_config=None, + serverless_inference_config=None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -715,9 +717,13 @@ def deploy( Args: initial_instance_count (int): The initial number of instances to run - in the ``Endpoint`` created from this ``Model``. + in the ``Endpoint`` created from this ``Model``. If not using + serverless inference, then it need to be a number larger or equals + to 1 (default: None) instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p2.xlarge', or 'local' for local mode. + For example, 'ml.p2.xlarge', or 'local' for local mode. If not using + serverless inference, then it is required to deploy a model. + (default: None) serializer (:class:`~sagemaker.serializers.BaseSerializer`): A serializer object, used to encode data for an inference endpoint (default: None). If ``serializer`` is not None, then @@ -746,7 +752,17 @@ def deploy( data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. - + serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): + Specifies configuration related to serverless endpoint. Use this configuration + when trying to create serverless endpoint and make serverless inference. If + empty object passed through, we will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None) + Raises: + ValueError: If arguments combination check failed in these circumstances: + - If no role is specified or + - If serverless inference config is not specified and instance type and instance + count are also not specified or + - If a wrong type of object is provided as serverless inference config Returns: callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` @@ -758,27 +774,47 @@ def deploy( if self.role is None: raise ValueError("Role can not be null for deploying a model") - if instance_type.startswith("ml.inf") and not self._is_compiled_model: + is_serverless = serverless_inference_config is not None + if not is_serverless and not (instance_type and initial_instance_count): + raise ValueError( + "Must specify instance type and instance count unless using serverless inference" + ) + + if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig): + raise ValueError( + "serverless_inference_config needs to be a ServerlessInferenceConfig object" + ) + + if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model: LOGGER.warning( "Your model is not compiled. Please compile your model before using Inferentia." ) - compiled_model_suffix = "-".join(instance_type.split(".")[:-1]) - if self._is_compiled_model: + compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1]) + if self._is_compiled_model and not is_serverless: self._ensure_base_name_if_needed(self.image_uri) if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) self._create_sagemaker_model(instance_type, accelerator_type, tags) + + serverless_inference_config_dict = ( + serverless_inference_config._to_request_dict() if is_serverless else None + ) production_variant = sagemaker.production_variant( - self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type + self.name, + instance_type, + initial_instance_count, + accelerator_type=accelerator_type, + serverless_inference_config=serverless_inference_config_dict, ) if endpoint_name: self.endpoint_name = endpoint_name else: base_endpoint_name = self._base_name or utils.base_from_name(self.name) - if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix): - base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix)) + if self._is_compiled_model and not is_serverless: + if not base_endpoint_name.endswith(compiled_model_suffix): + base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix)) self.endpoint_name = utils.name_from_base(base_endpoint_name) data_capture_config_dict = None diff --git a/src/sagemaker/serverless/__init__.py b/src/sagemaker/serverless/__init__.py index 8bf55c0dcd..4ecffb56d8 100644 --- a/src/sagemaker/serverless/__init__.py +++ b/src/sagemaker/serverless/__init__.py @@ -13,3 +13,6 @@ """Classes for performing machine learning on serverless compute.""" from sagemaker.serverless.model import LambdaModel # noqa: F401 from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401 +from sagemaker.serverless.serverless_inference_config import ( # noqa: F401 + ServerlessInferenceConfig, +) diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py new file mode 100644 index 0000000000..39950f4f84 --- /dev/null +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code related to the ServerlessInferenceConfig class. + +Codes are used for configuring async inference endpoint. Use it when deploying +the model to the endpoints. +""" +from __future__ import print_function, absolute_import + + +class ServerlessInferenceConfig(object): + """Configuration object passed in when deploying models to Amazon SageMaker Endpoints. + + This object specifies configuration related to serverless endpoint. Use this configuration + when trying to create serverless endpoint and make serverless inference + """ + + def __init__( + self, + memory_size_in_mb=2048, + max_concurrency=5, + ): + """Initialize a ServerlessInferenceConfig object for serverless inference configuration. + + Args: + memory_size_in_mb (int): Optional. The memory size of your serverless endpoint. + Valid values are in 1 GB increments: 1024 MB, 2048 MB, 3072 MB, 4096 MB, + 5120 MB, or 6144 MB. If no value is provided, Amazon SageMaker will choose + the default value for you. (Default: 2048) + max_concurrency (int): Optional. The maximum number of concurrent invocations + your serverless endpoint can process. If no value is provided, Amazon + SageMaker will choose the default value for you. (Default: 5) + """ + self.memory_size_in_mb = memory_size_in_mb + self.max_concurrency = max_concurrency + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = { + "MemorySizeInMB": self.memory_size_in_mb, + "MaxConcurrency": self.max_concurrency, + } + + return request_dict diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 56f008be84..1de9571ac6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4382,11 +4382,12 @@ def pipeline_container_def(models, instance_type=None): def production_variant( model_name, - instance_type, - initial_instance_count=1, + instance_type=None, + initial_instance_count=None, variant_name="AllTraffic", initial_weight=1, accelerator_type=None, + serverless_inference_config=None, ): """Create a production variant description suitable for use in a ``ProductionVariant`` list. @@ -4405,14 +4406,15 @@ def production_variant( accelerator_type (str): Type of Elastic Inference accelerator for this production variant. For example, 'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html + serverless_inference_config (dict): Specifies configuration dict related to serverless + endpoint. The dict is converted from sagemaker.model_monitor.ServerlessInferenceConfig + object (default: None) Returns: dict[str, str]: An SageMaker ``ProductionVariant`` description """ production_variant_configuration = { "ModelName": model_name, - "InstanceType": instance_type, - "InitialInstanceCount": initial_instance_count, "VariantName": variant_name, "InitialVariantWeight": initial_weight, } @@ -4420,6 +4422,13 @@ def production_variant( if accelerator_type: production_variant_configuration["AcceleratorType"] = accelerator_type + if serverless_inference_config: + production_variant_configuration["ServerlessConfig"] = serverless_inference_config + else: + initial_instance_count = initial_instance_count or 1 + production_variant_configuration["InitialInstanceCount"] = initial_instance_count + production_variant_configuration["InstanceType"] = instance_type + return production_variant_configuration diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index d4eb3e60aa..d13bdc8ffa 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -258,8 +258,8 @@ def register( def deploy( self, - initial_instance_count, - instance_type, + initial_instance_count=None, + instance_type=None, serializer=None, deserializer=None, accelerator_type=None, @@ -269,6 +269,7 @@ def deploy( wait=True, data_capture_config=None, update_endpoint=None, + serverless_inference_config=None, ): """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``.""" @@ -287,6 +288,7 @@ def deploy( kms_key=kms_key, wait=wait, data_capture_config=data_capture_config, + serverless_inference_config=serverless_inference_config, update_endpoint=update_endpoint, ) diff --git a/tests/integ/test_serverless_inference.py b/tests/integ/test_serverless_inference.py new file mode 100644 index 0000000000..40b1ace147 --- /dev/null +++ b/tests/integ/test_serverless_inference.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +import sagemaker.amazon.pca +from sagemaker.utils import unique_name_from_base +from sagemaker.serverless import ServerlessInferenceConfig +from tests.integ import datasets, TRAINING_DEFAULT_TIMEOUT_MINUTES +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name + + +@pytest.fixture +def training_set(): + return datasets.one_p_mnist() + + +def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_set): + job_name = unique_name_from_base("pca") + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + pca = sagemaker.amazon.pca.PCA( + role="SageMakerRole", + instance_count=1, + instance_type=cpu_instance_type, + num_components=48, + sagemaker_session=sagemaker_session, + enable_network_isolation=True, + ) + + pca.algorithm_mode = "randomized" + pca.subtract_mean = True + pca.extra_components = 5 + pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name) + + with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): + + predictor_serverless = pca.deploy( + endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig() + ) + + result = predictor_serverless.predict(training_set[0][:5]) + + assert len(result) == 5 + for record in result: + assert record.label["projection"] is not None diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 284956aa75..03af3acb7d 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -19,6 +19,7 @@ import sagemaker from sagemaker.model import Model +from sagemaker.serverless import ServerlessInferenceConfig MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -62,7 +63,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None) production_variant.assert_called_with( - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None + MODEL_NAME, + INSTANCE_TYPE, + INSTANCE_COUNT, + accelerator_type=None, + serverless_inference_config=None, ) sagemaker_session.create_model.assert_called_with( @@ -101,7 +106,11 @@ def test_deploy_accelerator_type( create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None) production_variant.assert_called_with( - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=ACCELERATOR_TYPE + MODEL_NAME, + INSTANCE_TYPE, + INSTANCE_COUNT, + accelerator_type=ACCELERATOR_TYPE, + serverless_inference_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -279,6 +288,71 @@ def test_deploy_data_capture_config(production_variant, name_from_base, sagemake ) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.model.Model._create_sagemaker_model") +@patch("sagemaker.production_variant") +def test_deploy_serverless_inference(production_variant, create_sagemaker_model, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT) + production_variant.return_value = production_variant_result + + serverless_inference_config = ServerlessInferenceConfig() + serverless_inference_config_dict = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 5, + } + + model.deploy( + serverless_inference_config=serverless_inference_config, + ) + + create_sagemaker_model.assert_called_with(None, None, None) + production_variant.assert_called_with( + MODEL_NAME, + None, + None, + accelerator_type=None, + serverless_inference_config=serverless_inference_config_dict, + ) + + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=ENDPOINT_NAME, + production_variants=[production_variant_result], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +def test_deploy_wrong_inference_type(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) + + bad_args = ( + {"instance_type": INSTANCE_TYPE}, + {"initial_instance_count": INSTANCE_COUNT}, + {"instance_type": None, "initial_instance_count": None}, + ) + for args in bad_args: + with pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + model.deploy(args) + + +def test_deploy_wrong_serverless_config(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) + with pytest.raises( + ValueError, + match="serverless_inference_config needs to be a ServerlessInferenceConfig object", + ): + model.deploy(serverless_inference_config={}) + + @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): diff --git a/tests/unit/sagemaker/serverless/test_serverless_inference_config.py b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py new file mode 100644 index 0000000000..fab80748a4 --- /dev/null +++ b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.serverless import ServerlessInferenceConfig + +DEFAULT_MEMORY_SIZE_IN_MB = 2048 +DEFAULT_MAX_CONCURRENCY = 5 + +DEFAULT_REQUEST_DICT = { + "MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB, + "MaxConcurrency": DEFAULT_MAX_CONCURRENCY, +} + + +def test_init(): + serverless_inference_config = ServerlessInferenceConfig() + + assert serverless_inference_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB + assert serverless_inference_config.max_concurrency == DEFAULT_MAX_CONCURRENCY + + +def test_to_request_dict(): + serverless_inference_config_dict = ServerlessInferenceConfig()._to_request_dict() + + assert serverless_inference_config_dict == DEFAULT_REQUEST_DICT diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 248eda1aa5..5940ca2c0a 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2801,6 +2801,37 @@ def test_generic_to_deploy(time, sagemaker_session): assert predictor.sagemaker_session == sagemaker_session +def test_generic_to_deploy_bad_arguments_combination(sagemaker_session): + e = Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) + + e.fit() + + bad_args = ( + {"instance_type": INSTANCE_TYPE}, + {"initial_instance_count": INSTANCE_COUNT}, + {"instance_type": None, "initial_instance_count": None}, + ) + for args in bad_args: + with pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + e.deploy(args) + + with pytest.raises( + ValueError, + match="serverless_inference_config needs to be a ServerlessInferenceConfig object", + ): + e.deploy(serverless_inference_config={}) + + def test_generic_to_deploy_network_isolation(sagemaker_session): e = Estimator( IMAGE_URI, @@ -2850,6 +2881,7 @@ def test_generic_to_deploy_kms(create_model, sagemaker_session): wait=True, kms_key=kms_key, data_capture_config=None, + serverless_inference_config=None, ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b2c14c5e5a..9a63a6c114 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -749,6 +749,11 @@ def test_training_input_all_arguments(): IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "InProgress"}) +SERVERLESS_INFERENCE_CONFIG = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 2, +} + @pytest.fixture() def sagemaker_session(): @@ -1911,6 +1916,31 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi ) +def test_endpoint_from_production_variants_with_serverless_inference_config(sagemaker_session): + ims = sagemaker_session + ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"}) + pvs = [ + sagemaker.production_variant( + "A", "ml.p2.xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG + ), + sagemaker.production_variant( + "B", "p299.4096xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG + ), + ] + ex = ClientError( + {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + ) + ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) + tags = [{"ModelName": "TestModel"}] + sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs, tags) + sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( + EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=tags + ) + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=tags + ) + + def test_update_endpoint_succeed(sagemaker_session): sagemaker_session.sagemaker_client.describe_endpoint = Mock( return_value={"EndpointStatus": "InService"}