@@ -159,6 +159,7 @@ def test_mnist_async(sagemaker_session):
159159 training_job_name = estimator .latest_training_job .name
160160 time .sleep (20 )
161161 endpoint_name = training_job_name
162+ model_name = 'model-name-1'
162163 _assert_training_job_tags_match (
163164 sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
164165 )
@@ -167,7 +168,8 @@ def test_mnist_async(sagemaker_session):
167168 training_job_name = training_job_name , sagemaker_session = sagemaker_session
168169 )
169170 predictor = estimator .deploy (
170- initial_instance_count = 1 , instance_type = "ml.c4.xlarge" , endpoint_name = endpoint_name
171+ initial_instance_count = 1 , instance_type = "ml.c4.xlarge" , endpoint_name = endpoint_name ,
172+ model_name = model_name
171173 )
172174
173175 result = predictor .predict (np .zeros (784 ))
@@ -176,6 +178,9 @@ def test_mnist_async(sagemaker_session):
176178 _assert_model_tags_match (
177179 sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
178180 )
181+ _assert_model_name_match (
182+ sagemaker_session .sagemaker_client , endpoint_name , model_name
183+ )
179184
180185
181186def test_deploy_with_input_handlers (sagemaker_session , instance_type ):
@@ -241,3 +246,10 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
241246 TrainingJobName = training_job_name
242247 )
243248 _assert_tags_match (sagemaker_client , training_job_description ["TrainingJobArn" ], tags )
249+
250+
251+ def _assert_model_name_match (sagemaker_client , endpoint_config_name , model_name ):
252+ endpoint_config_description = sagemaker_client .describe_endpoint_config (
253+ EndpointConfigName = endpoint_config_name
254+ )
255+ assert model_name == endpoint_config_description ['ProductionVariants' ][0 ]['ModelName' ]
0 commit comments