@@ -159,6 +159,7 @@ def test_mnist_async(sagemaker_session):
159159 training_job_name = estimator .latest_training_job .name
160160 time .sleep (20 )
161161 endpoint_name = training_job_name
162+ model_name = "model-name-1"
162163 _assert_training_job_tags_match (
163164 sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
164165 )
@@ -167,7 +168,10 @@ def test_mnist_async(sagemaker_session):
167168 training_job_name = training_job_name , sagemaker_session = sagemaker_session
168169 )
169170 predictor = estimator .deploy (
170- initial_instance_count = 1 , instance_type = "ml.c4.xlarge" , endpoint_name = endpoint_name
171+ initial_instance_count = 1 ,
172+ instance_type = "ml.c4.xlarge" ,
173+ endpoint_name = endpoint_name ,
174+ model_name = model_name ,
171175 )
172176
173177 result = predictor .predict (np .zeros (784 ))
@@ -176,6 +180,7 @@ def test_mnist_async(sagemaker_session):
176180 _assert_model_tags_match (
177181 sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
178182 )
183+ _assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
179184
180185
181186def test_deploy_with_input_handlers (sagemaker_session , instance_type ):
@@ -241,3 +246,10 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
241246 TrainingJobName = training_job_name
242247 )
243248 _assert_tags_match (sagemaker_client , training_job_description ["TrainingJobArn" ], tags )
249+
250+
251+ def _assert_model_name_match (sagemaker_client , endpoint_config_name , model_name ):
252+ endpoint_config_description = sagemaker_client .describe_endpoint_config (
253+ EndpointConfigName = endpoint_config_name
254+ )
255+ assert model_name == endpoint_config_description ["ProductionVariants" ][0 ]["ModelName" ]
0 commit comments