@@ -308,29 +308,26 @@ def test_create_model_with_optional_params(sagemaker_session):
308308 assert model .vpc_config == vpc_config
309309
310310
311- @patch ("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model " )
312- def test_transformer_creation_with_endpoint_type (create_tfs_model , sagemaker_session ):
311+ @patch ("sagemaker.tensorflow.estimator.TensorFlow.create_model " )
312+ def test_transformer_creation_with_endpoint_type (create_model , sagemaker_session ):
313313 tf = TensorFlow (
314314 entry_point = SCRIPT_PATH ,
315315 role = ROLE ,
316316 sagemaker_session = sagemaker_session ,
317317 train_instance_count = INSTANCE_COUNT ,
318318 train_instance_type = INSTANCE_TYPE ,
319319 )
320-
321320 tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
322- transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , endpoint_type = "tensorflow-serving" )
323- assert isinstance (transformer , Transformer )
324- assert transformer .sagemaker_session == sagemaker_session
325- assert transformer .instance_count == INSTANCE_COUNT
326- assert transformer .instance_type == INSTANCE_TYPE
327- assert tf .script_mode is True
328- assert tf ._script_mode_enabled () is True
329- create_tfs_model .assert_called_once ()
321+
322+ tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , model_server_workers = 2 , endpoint_type = "tensorflow-serving" )
323+ create_model .assert_called_with (endpoint_type = 'tensorflow-serving' ,
324+ model_server_workers = 2 ,
325+ role = 'Dummy' ,
326+ vpc_config_override = 'VPC_CONFIG_DEFAULT' )
330327
331328
332- @patch ("sagemaker.tensorflow.estimator.TensorFlow._create_default_model " )
333- def test_transformer_creation_without_endpoint_type (create_default_model , sagemaker_session ):
329+ @patch ("sagemaker.tensorflow.estimator.TensorFlow.create_model " )
330+ def test_transformer_creation_without_endpoint_type (create_model , sagemaker_session ):
334331
335332 tf = TensorFlow (
336333 entry_point = SCRIPT_PATH ,
@@ -341,14 +338,11 @@ def test_transformer_creation_without_endpoint_type(create_default_model, sagema
341338 )
342339
343340 tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
344- transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
345- assert isinstance (transformer , Transformer )
346- assert transformer .sagemaker_session == sagemaker_session
347- assert transformer .instance_count == INSTANCE_COUNT
348- assert transformer .instance_type == INSTANCE_TYPE
349- assert tf .script_mode is False
350- assert tf ._script_mode_enabled () is False
351- create_default_model .assert_called_once ()
341+ transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , model_server_workers = 4 )
342+ create_model .assert_called_with (endpoint_type = None ,
343+ model_server_workers = 4 ,
344+ role = 'Dummy' ,
345+ vpc_config_override = 'VPC_CONFIG_DEFAULT' )
352346
353347
354348def test_create_model_with_custom_image (sagemaker_session ):
0 commit comments